From 7d60941dbefbf8fc526052fa1398b53b788e8c02 Mon Sep 17 00:00:00 2001 From: Tze-Yang Tung Date: Fri, 13 Sep 2024 10:05:25 -0400 Subject: [PATCH 1/7] adding litellm integration into toolkit --- notdiamond/toolkit/litellm.py | 6251 ++++++++++++++++++++++ notdiamond/toolkit/litellm_notdiamond.py | 266 + poetry.lock | 50 +- tests/test_toolkit/test_litellm.py | 216 + 4 files changed, 6780 insertions(+), 3 deletions(-) create mode 100644 notdiamond/toolkit/litellm.py create mode 100644 notdiamond/toolkit/litellm_notdiamond.py create mode 100644 tests/test_toolkit/test_litellm.py diff --git a/notdiamond/toolkit/litellm.py b/notdiamond/toolkit/litellm.py new file mode 100644 index 00000000..591bf569 --- /dev/null +++ b/notdiamond/toolkit/litellm.py @@ -0,0 +1,6251 @@ +# flake8: noqa + +import asyncio +import contextvars +import datetime +import inspect +import json +import os +import random +import sys +import threading +import time +import traceback +import uuid +from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy +from functools import partial +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Mapping, + Optional, + Tuple, + Type, + Union, +) + +import dotenv +import httpx +import litellm +import openai +import tiktoken +from litellm import ( # type: ignore + Logging, + client, + exception_type, + get_litellm_params, + get_optional_params, +) +from litellm._logging import verbose_logger +from litellm.caching import disable_cache, enable_cache, update_cache +from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.litellm_logging import ( + Logging as LiteLLMLoggingObj, +) +from litellm.llms import ( + aleph_alpha, + baseten, + clarifai, + cloudflare, + maritalk, + nlp_cloud, + ollama, + ollama_chat, + oobabooga, + openrouter, + palm, + petals, + replicate, + vllm, +) +from litellm.llms.AI21 import completion as ai21 +from litellm.llms.anthropic.chat import AnthropicChatCompletion +from litellm.llms.anthropic.completion import AnthropicTextCompletion +from litellm.llms.azure_text import AzureTextCompletion +from litellm.llms.AzureOpenAI.audio_transcriptions import ( + AzureAudioTranscription, +) +from litellm.llms.AzureOpenAI.azure import ( + AzureChatCompletion, + _check_dynamic_azure_params, +) +from litellm.llms.bedrock import ( + image_generation as bedrock_image_generation, # type: ignore +) +from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM +from litellm.llms.bedrock.embed.embedding import BedrockEmbedding +from litellm.llms.cohere import chat as cohere_chat +from litellm.llms.cohere import completion as cohere_completion # type: ignore +from litellm.llms.cohere import embed as cohere_embed +from litellm.llms.custom_llm import CustomLLM, custom_chat_llm_router +from litellm.llms.databricks.chat import DatabricksChatCompletion +from litellm.llms.huggingface_restapi import Huggingface +from litellm.llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription +from litellm.llms.OpenAI.openai import ( + OpenAIChatCompletion, + OpenAITextCompletion, +) +from litellm.llms.predibase import PredibaseChatCompletion +from litellm.llms.prompt_templates.factory import ( + custom_prompt, + function_call_prompt, + map_system_message_pt, + prompt_factory, + stringify_json_tool_call_content, +) +from litellm.llms.sagemaker.sagemaker import SagemakerLLM +from litellm.llms.text_completion_codestral import CodestralTextCompletion +from litellm.llms.triton import TritonChatCompletion +from litellm.llms.vertex_ai_and_google_ai_studio import ( + vertex_ai_anthropic, + vertex_ai_non_gemini, +) +from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( + VertexLLM, +) +from litellm.llms.vertex_ai_and_google_ai_studio.gemini_embeddings.batch_embed_content_handler import ( + GoogleBatchEmbeddings, +) +from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import ( + VertexImageGeneration, +) +from litellm.llms.vertex_ai_and_google_ai_studio.multimodal_embeddings.embedding_handler import ( + VertexMultimodalEmbedding, +) +from litellm.llms.vertex_ai_and_google_ai_studio.text_to_speech.text_to_speech_handler import ( + VertexTextToSpeechAPI, +) +from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import ( + VertexAIPartnerModels, +) +from litellm.llms.vertex_ai_and_google_ai_studio.vertex_embeddings import ( + embedding_handler as vertex_ai_embedding_handler, +) +from litellm.llms.watsonx import IBMWatsonXAI +from litellm.types.llms.openai import HttpxBinaryResponseContent +from litellm.types.utils import ( + AdapterCompletionStreamWrapper, + ChatCompletionMessageToolCall, + FileTypes, + HiddenParams, + all_litellm_params, +) +from litellm.utils import ( + CustomStreamWrapper, + Usage, + async_mock_completion_streaming_obj, + completion_with_fallbacks, + convert_to_model_response_object, + create_pretrained_tokenizer, + create_tokenizer, + get_optional_params_embeddings, + get_optional_params_image_gen, + get_optional_params_transcription, + get_secret, + mock_completion_streaming_obj, + read_config_args, + supports_httpx_timeout, + token_counter, +) +from pydantic import BaseModel +from typing_extensions import overload + +encoding = tiktoken.get_encoding("cl100k_base") +from litellm.types.router import LiteLLM_Params +from litellm.utils import ( + Choices, + CustomStreamWrapper, + EmbeddingResponse, + ImageResponse, + Message, + ModelResponse, + TextChoices, + TextCompletionResponse, + TextCompletionStreamWrapper, + TranscriptionResponse, + get_secret, + read_config_args, +) + +from .litellm_notdiamond import completion as notdiamond_completion + +openai_chat_completions = OpenAIChatCompletion() +openai_text_completions = OpenAITextCompletion() +openai_audio_transcriptions = OpenAIAudioTranscription() +databricks_chat_completions = DatabricksChatCompletion() +anthropic_chat_completions = AnthropicChatCompletion() +anthropic_text_completions = AnthropicTextCompletion() +azure_chat_completions = AzureChatCompletion() +azure_text_completions = AzureTextCompletion() +azure_audio_transcriptions = AzureAudioTranscription() +huggingface = Huggingface() +predibase_chat_completions = PredibaseChatCompletion() +codestral_text_completions = CodestralTextCompletion() +triton_chat_completions = TritonChatCompletion() +bedrock_chat_completion = BedrockLLM() +bedrock_converse_chat_completion = BedrockConverseLLM() +bedrock_embedding = BedrockEmbedding() +vertex_chat_completion = VertexLLM() +vertex_multimodal_embedding = VertexMultimodalEmbedding() +vertex_image_generation = VertexImageGeneration() +google_batch_embeddings = GoogleBatchEmbeddings() +vertex_partner_models_chat_completion = VertexAIPartnerModels() +vertex_text_to_speech = VertexTextToSpeechAPI() +watsonxai = IBMWatsonXAI() +sagemaker_llm = SagemakerLLM() + +litellm.provider_list.append("notdiamond") +litellm.notdiamond_key = None + + +class LiteLLM: + def __init__( + self, + *, + api_key=None, + organization: Optional[str] = None, + base_url: Optional[str] = None, + timeout: Optional[float] = 600, + max_retries: Optional[int] = litellm.num_retries, + default_headers: Optional[Mapping[str, str]] = None, + ): + self.params = locals() + self.chat = Chat(self.params, router_obj=None) + + +class Chat: + def __init__(self, params, router_obj: Optional[Any]): + self.params = params + if self.params.get("acompletion", False) == True: + self.params.pop("acompletion") + self.completions: Union[ + AsyncCompletions, Completions + ] = AsyncCompletions(self.params, router_obj=router_obj) + else: + self.completions = Completions(self.params, router_obj=router_obj) + + +class Completions: + def __init__(self, params, router_obj: Optional[Any]): + self.params = params + self.router_obj = router_obj + + def create(self, messages, model=None, **kwargs): + for k, v in kwargs.items(): + self.params[k] = v + model = model or self.params.get("model") + if self.router_obj is not None: + response = self.router_obj.completion( + model=model, messages=messages, **self.params + ) + else: + response = completion( + model=model, messages=messages, **self.params + ) + return response + + +class AsyncCompletions: + def __init__(self, params, router_obj: Optional[Any]): + self.params = params + self.router_obj = router_obj + + async def create(self, messages, model=None, **kwargs): + for k, v in kwargs.items(): + self.params[k] = v + model = model or self.params.get("model") + if self.router_obj is not None: + response = await self.router_obj.acompletion( + model=model, messages=messages, **self.params + ) + else: + response = await acompletion( + model=model, messages=messages, **self.params + ) + return response + + +def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]): + api_key = dynamic_api_key or litellm.api_key + # openai + if llm_provider == "openai" or llm_provider == "text-completion-openai": + api_key = api_key or litellm.openai_key or get_secret("OPENAI_API_KEY") + # anthropic + elif llm_provider == "anthropic": + api_key = ( + api_key or litellm.anthropic_key or get_secret("ANTHROPIC_API_KEY") + ) + # ai21 + elif llm_provider == "ai21": + api_key = api_key or litellm.ai21_key or get_secret("AI211_API_KEY") + # aleph_alpha + elif llm_provider == "aleph_alpha": + api_key = ( + api_key + or litellm.aleph_alpha_key + or get_secret("ALEPH_ALPHA_API_KEY") + ) + # baseten + elif llm_provider == "baseten": + api_key = ( + api_key or litellm.baseten_key or get_secret("BASETEN_API_KEY") + ) + # cohere + elif llm_provider == "cohere" or llm_provider == "cohere_chat": + api_key = api_key or litellm.cohere_key or get_secret("COHERE_API_KEY") + # huggingface + elif llm_provider == "huggingface": + api_key = ( + api_key + or litellm.huggingface_key + or get_secret("HUGGINGFACE_API_KEY") + ) + # notdiamond + elif llm_provider == "notdiamond": + api_key = ( + api_key + or litellm.notdiamond_key + or get_secret("NOTDIAMOND_API_KEY") + ) + # nlp_cloud + elif llm_provider == "nlp_cloud": + api_key = ( + api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") + ) + # replicate + elif llm_provider == "replicate": + api_key = ( + api_key or litellm.replicate_key or get_secret("REPLICATE_API_KEY") + ) + # together_ai + elif llm_provider == "together_ai": + api_key = ( + api_key + or litellm.togetherai_api_key + or get_secret("TOGETHERAI_API_KEY") + or get_secret("TOGETHER_AI_TOKEN") + ) + return api_key + + +def get_llm_provider( + model: str, + custom_llm_provider: Optional[str] = None, + api_base: Optional[str] = None, + api_key: Optional[str] = None, + litellm_params: Optional[LiteLLM_Params] = None, +) -> Tuple[str, str, Optional[str], Optional[str]]: + """ + Returns the provider for a given model name - e.g. 'azure/chatgpt-v-2' -> 'azure' + + For router -> Can also give the whole litellm param dict -> this function will extract the relevant details + + Raises Error - if unable to map model to a provider + """ + try: + ## IF LITELLM PARAMS GIVEN ## + if litellm_params is not None: + assert ( + custom_llm_provider is None + and api_base is None + and api_key is None + ), "Either pass in litellm_params or the custom_llm_provider/api_base/api_key. Otherwise, these values will be overriden." + custom_llm_provider = litellm_params.custom_llm_provider + api_base = litellm_params.api_base + api_key = litellm_params.api_key + + dynamic_api_key = None + # check if llm provider provided + # AZURE AI-Studio Logic - Azure AI Studio supports AZURE/Cohere + # If User passes azure/command-r-plus -> we should send it to cohere_chat/command-r-plus + if model.split("/", 1)[0] == "azure": + if _is_non_openai_azure_model(model): + custom_llm_provider = "openai" + return model, custom_llm_provider, dynamic_api_key, api_base + + if custom_llm_provider: + return model, custom_llm_provider, dynamic_api_key, api_base + + if api_key and api_key.startswith("os.environ/"): + dynamic_api_key = get_secret(api_key) + # check if llm provider part of model name + if ( + model.split("/", 1)[0] in litellm.provider_list + and model.split("/", 1)[0] not in litellm.model_list + and len(model.split("/")) + > 1 # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351 + ): + custom_llm_provider = model.split("/", 1)[0] + model = model.split("/", 1)[1] + if custom_llm_provider == "perplexity": + # perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai + api_base = api_base or "https://api.perplexity.ai" + dynamic_api_key = api_key or get_secret("PERPLEXITYAI_API_KEY") + elif custom_llm_provider == "anyscale": + # anyscale is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 + api_base = api_base or "https://api.endpoints.anyscale.com/v1" + dynamic_api_key = api_key or get_secret("ANYSCALE_API_KEY") + elif custom_llm_provider == "deepinfra": + # deepinfra is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 + api_base = api_base or "https://api.deepinfra.com/v1/openai" + dynamic_api_key = api_key or get_secret("DEEPINFRA_API_KEY") + elif custom_llm_provider == "empower": + api_base = api_base or "https://app.empower.dev/api/v1" + dynamic_api_key = api_key or get_secret("EMPOWER_API_KEY") + elif custom_llm_provider == "groq": + # groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1 + api_base = api_base or "https://api.groq.com/openai/v1" + dynamic_api_key = api_key or get_secret("GROQ_API_KEY") + elif custom_llm_provider == "nvidia_nim": + # nvidia_nim is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 + api_base = api_base or "https://integrate.api.nvidia.com/v1" + dynamic_api_key = api_key or get_secret("NVIDIA_NIM_API_KEY") + elif custom_llm_provider == "volcengine": + # volcengine is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.endpoints.anyscale.com/v1 + api_base = ( + api_base or "https://ark.cn-beijing.volces.com/api/v3" + ) + dynamic_api_key = api_key or get_secret("VOLCENGINE_API_KEY") + elif custom_llm_provider == "codestral": + # codestral is openai compatible, we just need to set this to custom_openai and have the api_base be https://codestral.mistral.ai/v1 + api_base = api_base or "https://codestral.mistral.ai/v1" + dynamic_api_key = api_key or get_secret("CODESTRAL_API_KEY") + elif custom_llm_provider == "deepseek": + # deepseek is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.deepseek.com/v1 + api_base = api_base or "https://api.deepseek.com/v1" + dynamic_api_key = api_key or get_secret("DEEPSEEK_API_KEY") + elif custom_llm_provider == "fireworks_ai": + # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1 + if not model.startswith("accounts/"): + model = f"accounts/fireworks/models/{model}" + api_base = api_base or "https://api.fireworks.ai/inference/v1" + dynamic_api_key = api_key or ( + get_secret("FIREWORKS_API_KEY") + or get_secret("FIREWORKS_AI_API_KEY") + or get_secret("FIREWORKSAI_API_KEY") + or get_secret("FIREWORKS_AI_TOKEN") + ) + elif custom_llm_provider == "azure_ai": + api_base = api_base or get_secret("AZURE_AI_API_BASE") # type: ignore + dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY") + elif custom_llm_provider == "mistral": + # mistral is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.mistral.ai + api_base = ( + api_base + or get_secret( + "MISTRAL_AZURE_API_BASE" + ) # for Azure AI Mistral + or "https://api.mistral.ai/v1" + ) # type: ignore + + # if api_base does not end with /v1 we add it + if api_base is not None and not api_base.endswith( + "/v1" + ): # Mistral always needs a /v1 at the end + api_base = api_base + "/v1" + dynamic_api_key = ( + api_key + or get_secret( + "MISTRAL_AZURE_API_KEY" + ) # for Azure AI Mistral + or get_secret("MISTRAL_API_KEY") + ) + elif custom_llm_provider == "voyage": + # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1 + api_base = "https://api.voyageai.com/v1" + dynamic_api_key = api_key or get_secret("VOYAGE_API_KEY") + elif custom_llm_provider == "together_ai": + api_base = "https://api.together.xyz/v1" + dynamic_api_key = api_key or ( + get_secret("TOGETHER_API_KEY") + or get_secret("TOGETHER_AI_API_KEY") + or get_secret("TOGETHERAI_API_KEY") + or get_secret("TOGETHER_AI_TOKEN") + ) + elif custom_llm_provider == "friendliai": + api_base = ( + api_base + or get_secret("FRIENDLI_API_BASE") + or "https://inference.friendli.ai/v1" + ) + dynamic_api_key = ( + api_key + or get_secret("FRIENDLIAI_API_KEY") + or get_secret("FRIENDLI_TOKEN") + ) + elif custom_llm_provider == "notdiamond": + api_base = "https://not-diamond-server.onrender.com/v2/optimizer/modelSelect" + dynamic_api_key = get_secret("NOTDIAMOND_API_KEY") or None + if api_base is not None and not isinstance(api_base, str): + raise Exception( + "api base needs to be a string. api_base={}".format( + api_base + ) + ) + if dynamic_api_key is not None and not isinstance( + dynamic_api_key, str + ): + raise Exception( + "dynamic_api_key needs to be a string. dynamic_api_key={}".format( + dynamic_api_key + ) + ) + return model, custom_llm_provider, dynamic_api_key, api_base + elif model.split("/", 1)[0] in litellm.provider_list: + custom_llm_provider = model.split("/", 1)[0] + model = model.split("/", 1)[1] + if api_base is not None and not isinstance(api_base, str): + raise Exception( + "api base needs to be a string. api_base={}".format( + api_base + ) + ) + if dynamic_api_key is not None and not isinstance( + dynamic_api_key, str + ): + raise Exception( + "dynamic_api_key needs to be a string. dynamic_api_key={}".format( + dynamic_api_key + ) + ) + return model, custom_llm_provider, dynamic_api_key, api_base + # check if api base is a known openai compatible endpoint + if api_base: + for endpoint in litellm.openai_compatible_endpoints: + if endpoint in api_base: + if endpoint == "api.perplexity.ai": + custom_llm_provider = "perplexity" + dynamic_api_key = get_secret("PERPLEXITYAI_API_KEY") + elif endpoint == "api.endpoints.anyscale.com/v1": + custom_llm_provider = "anyscale" + dynamic_api_key = get_secret("ANYSCALE_API_KEY") + elif endpoint == "api.deepinfra.com/v1/openai": + custom_llm_provider = "deepinfra" + dynamic_api_key = get_secret("DEEPINFRA_API_KEY") + elif endpoint == "api.mistral.ai/v1": + custom_llm_provider = "mistral" + dynamic_api_key = get_secret("MISTRAL_API_KEY") + elif endpoint == "api.groq.com/openai/v1": + custom_llm_provider = "groq" + dynamic_api_key = get_secret("GROQ_API_KEY") + elif endpoint == "https://integrate.api.nvidia.com/v1": + custom_llm_provider = "nvidia_nim" + dynamic_api_key = get_secret("NVIDIA_NIM_API_KEY") + elif endpoint == "https://codestral.mistral.ai/v1": + custom_llm_provider = "codestral" + dynamic_api_key = get_secret("CODESTRAL_API_KEY") + elif endpoint == "https://codestral.mistral.ai/v1": + custom_llm_provider = "text-completion-codestral" + dynamic_api_key = get_secret("CODESTRAL_API_KEY") + elif endpoint == "app.empower.dev/api/v1": + custom_llm_provider = "empower" + dynamic_api_key = get_secret("EMPOWER_API_KEY") + elif endpoint == "api.deepseek.com/v1": + custom_llm_provider = "deepseek" + dynamic_api_key = get_secret("DEEPSEEK_API_KEY") + elif endpoint == "inference.friendli.ai/v1": + custom_llm_provider = "friendliai" + dynamic_api_key = get_secret( + "FRIENDLIAI_API_KEY" + ) or get_secret("FRIENDLI_TOKEN") + + if api_base is not None and not isinstance(api_base, str): + raise Exception( + "api base needs to be a string. api_base={}".format( + api_base + ) + ) + if dynamic_api_key is not None and not isinstance( + dynamic_api_key, str + ): + raise Exception( + "dynamic_api_key needs to be a string. dynamic_api_key={}".format( + dynamic_api_key + ) + ) + return model, custom_llm_provider, dynamic_api_key, api_base # type: ignore + + # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.) + ## openai - chatcompletion + text completion + if ( + model in litellm.open_ai_chat_completion_models + or "ft:gpt-3.5-turbo" in model + or "ft:gpt-4" in model # catches ft:gpt-4-0613, ft:gpt-4o + or model in litellm.openai_image_generation_models + ): + custom_llm_provider = "openai" + elif model in litellm.open_ai_text_completion_models: + custom_llm_provider = "text-completion-openai" + ## anthropic + elif model in litellm.anthropic_models: + custom_llm_provider = "anthropic" + ## cohere + elif ( + model in litellm.cohere_models + or model in litellm.cohere_embedding_models + ): + custom_llm_provider = "cohere" + ## cohere chat models + elif model in litellm.cohere_chat_models: + custom_llm_provider = "cohere_chat" + ## replicate + elif model in litellm.replicate_models or ( + ":" in model and len(model) > 64 + ): + model_parts = model.split(":") + if ( + len(model_parts) > 1 and len(model_parts[1]) == 64 + ): ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3" + custom_llm_provider = "replicate" + elif model in litellm.replicate_models: + custom_llm_provider = "replicate" + ## openrouter + elif model in litellm.openrouter_models: + custom_llm_provider = "openrouter" + ## openrouter + elif model in litellm.maritalk_models: + custom_llm_provider = "maritalk" + ## vertex - text + chat + language (gemini) models + elif ( + model in litellm.vertex_chat_models + or model in litellm.vertex_code_chat_models + or model in litellm.vertex_text_models + or model in litellm.vertex_code_text_models + or model in litellm.vertex_language_models + or model in litellm.vertex_embedding_models + or model in litellm.vertex_vision_models + ): + custom_llm_provider = "vertex_ai" + ## ai21 + elif model in litellm.ai21_models: + custom_llm_provider = "ai21" + ## aleph_alpha + elif model in litellm.aleph_alpha_models: + custom_llm_provider = "aleph_alpha" + ## baseten + elif model in litellm.baseten_models: + custom_llm_provider = "baseten" + ## nlp_cloud + elif model in litellm.nlp_cloud_models: + custom_llm_provider = "nlp_cloud" + ## petals + elif model in litellm.petals_models: + custom_llm_provider = "petals" + ## bedrock + elif ( + model in litellm.bedrock_models + or model in litellm.bedrock_embedding_models + ): + custom_llm_provider = "bedrock" + elif model in litellm.watsonx_models: + custom_llm_provider = "watsonx" + # openai embeddings + elif model in litellm.open_ai_embedding_models: + custom_llm_provider = "openai" + elif model in litellm.empower_models: + custom_llm_provider = "empower" + elif model == "*": + custom_llm_provider = "openai" + if custom_llm_provider is None or custom_llm_provider == "": + if litellm.suppress_debug_info == False: + print() # noqa + print( # noqa + "\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m" # noqa + ) # noqa + print() # noqa + error_str = f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. You passed model={model}\n Pass model as E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/starcoder',..)` Learn more: https://docs.litellm.ai/docs/providers" + # maps to openai.NotFoundError, this is raised when openai does not recognize the llm + raise litellm.exceptions.BadRequestError( # type: ignore + message=error_str, + model=model, + response=httpx.Response( + status_code=400, + content=error_str, + request=httpx.Request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + llm_provider="", + ) + if api_base is not None and not isinstance(api_base, str): + raise Exception( + "api base needs to be a string. api_base={}".format(api_base) + ) + if dynamic_api_key is not None and not isinstance( + dynamic_api_key, str + ): + raise Exception( + "dynamic_api_key needs to be a string. dynamic_api_key={}".format( + dynamic_api_key + ) + ) + return model, custom_llm_provider, dynamic_api_key, api_base + except Exception as e: + if isinstance(e, litellm.exceptions.BadRequestError): + raise e + else: + error_str = f"GetLLMProvider Exception - {str(e)}\n\noriginal model: {model}" + raise litellm.exceptions.BadRequestError( # type: ignore + message=f"GetLLMProvider Exception - {str(e)}\n\noriginal model: {model}", + model=model, + response=httpx.Response( + status_code=400, + content=error_str, + request=httpx.Request(method="completion", url="https://github.com/BerriAI/litellm"), # type: ignore + ), + llm_provider="", + ) + + +@client +async def acompletion( + model: str, + # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create + messages: List = [], + functions: Optional[List] = None, + function_call: Optional[str] = None, + timeout: Optional[Union[float, int]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stream: Optional[bool] = None, + stream_options: Optional[dict] = None, + stop=None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[dict] = None, + user: Optional[str] = None, + # openai v1.0+ new params + response_format: Optional[Union[dict, Type[BaseModel]]] = None, + seed: Optional[int] = None, + tools: Optional[List] = None, + tool_choice: Optional[str] = None, + parallel_tool_calls: Optional[bool] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + deployment_id=None, + # set api_base, api_version, api_key + base_url: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. + extra_headers: Optional[dict] = None, + # Optional liteLLM function params + **kwargs, +) -> Union[ModelResponse, CustomStreamWrapper]: + """ + Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) + + Parameters: + model (str): The name of the language model to use for text completion. see all supported LLMs: https://docs.litellm.ai/docs/providers/ + messages (List): A list of message objects representing the conversation context (default is an empty list). + + OPTIONAL PARAMS + functions (List, optional): A list of functions to apply to the conversation messages (default is an empty list). + function_call (str, optional): The name of the function to call within the conversation (default is an empty string). + temperature (float, optional): The temperature parameter for controlling the randomness of the output (default is 1.0). + top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0). + n (int, optional): The number of completions to generate (default is 1). + stream (bool, optional): If True, return a streaming response (default is False). + stream_options (dict, optional): A dictionary containing options for the streaming response. Only use this if stream is True. + stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens. + max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity). + presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far. + frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far. + logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion. + user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. + metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. + api_base (str, optional): Base URL for the API (default is None). + api_version (str, optional): API version (default is None). + api_key (str, optional): API key (default is None). + model_list (list, optional): List of api base, version, keys + timeout (float, optional): The maximum execution time in seconds for the completion request. + + LITELLM Specific Params + mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None). + custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock" + Returns: + ModelResponse: A response object containing the generated completion and associated metadata. + + Notes: + - This function is an asynchronous version of the `completion` function. + - The `completion` function is called using `run_in_executor` to execute synchronously in the event loop. + - If `stream` is True, the function returns an async generator that yields completion lines. + """ + loop = asyncio.get_event_loop() + custom_llm_provider = kwargs.get("custom_llm_provider", None) + # Adjusted to use explicit arguments instead of *args and **kwargs + completion_kwargs = { + "model": model, + "messages": messages, + "functions": functions, + "function_call": function_call, + "timeout": timeout, + "temperature": temperature, + "top_p": top_p, + "n": n, + "stream": stream, + "stream_options": stream_options, + "stop": stop, + "max_tokens": max_tokens, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "logit_bias": logit_bias, + "user": user, + "response_format": response_format, + "seed": seed, + "tools": tools, + "tool_choice": tool_choice, + "parallel_tool_calls": parallel_tool_calls, + "logprobs": logprobs, + "top_logprobs": top_logprobs, + "deployment_id": deployment_id, + "base_url": base_url, + "api_version": api_version, + "api_key": api_key, + "model_list": model_list, + "extra_headers": extra_headers, + "acompletion": True, # assuming this is a required parameter + } + if custom_llm_provider is None: + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=completion_kwargs.get("base_url", None) + ) + try: + # Use a partial function to pass your keyword arguments + func = partial(completion, **completion_kwargs, **kwargs) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + if ( + custom_llm_provider == "openai" + or custom_llm_provider == "azure" + or custom_llm_provider == "azure_text" + or custom_llm_provider == "custom_openai" + or custom_llm_provider == "anyscale" + or custom_llm_provider == "mistral" + or custom_llm_provider == "openrouter" + or custom_llm_provider == "deepinfra" + or custom_llm_provider == "perplexity" + or custom_llm_provider == "groq" + or custom_llm_provider == "nvidia_nim" + or custom_llm_provider == "cerebras" + or custom_llm_provider == "ai21_chat" + or custom_llm_provider == "volcengine" + or custom_llm_provider == "codestral" + or custom_llm_provider == "text-completion-codestral" + or custom_llm_provider == "deepseek" + or custom_llm_provider == "text-completion-openai" + or custom_llm_provider == "huggingface" + or custom_llm_provider == "ollama" + or custom_llm_provider == "ollama_chat" + or custom_llm_provider == "replicate" + or custom_llm_provider == "vertex_ai" + or custom_llm_provider == "vertex_ai_beta" + or custom_llm_provider == "gemini" + or custom_llm_provider == "sagemaker" + or custom_llm_provider == "sagemaker_chat" + or custom_llm_provider == "anthropic" + or custom_llm_provider == "predibase" + or custom_llm_provider == "bedrock" + or custom_llm_provider == "databricks" + or custom_llm_provider == "triton" + or custom_llm_provider == "clarifai" + or custom_llm_provider == "watsonx" + or custom_llm_provider == "notdiamond" + or custom_llm_provider in litellm.openai_compatible_providers + or custom_llm_provider in litellm._custom_providers + ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. + init_response = await loop.run_in_executor(None, func_with_context) + if isinstance(init_response, dict) or isinstance( + init_response, ModelResponse + ): ## CACHING SCENARIO + if isinstance(init_response, dict): + response = ModelResponse(**init_response) + response = init_response + elif asyncio.iscoroutine(init_response): + response = await init_response + else: + response = init_response # type: ignore + + if ( + custom_llm_provider == "text-completion-openai" + or custom_llm_provider == "text-completion-codestral" + ) and isinstance(response, TextCompletionResponse): + response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object( + response_object=response, + model_response_object=litellm.ModelResponse(), + ) + else: + # Call the synchronous function using run_in_executor + response = await loop.run_in_executor(None, func_with_context) # type: ignore + if isinstance(response, CustomStreamWrapper): + response.set_logging_event_loop( + loop=loop + ) # sets the logging event loop if the user does sync streaming (e.g. on proxy for sagemaker calls) + return response + except Exception as e: + custom_llm_provider = custom_llm_provider or "openai" + raise exception_type( + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=completion_kwargs, + extra_kwargs=kwargs, + ) + + +async def _async_streaming(response, model, custom_llm_provider, args): + try: + print_verbose(f"received response in _async_streaming: {response}") + if asyncio.iscoroutine(response): + response = await response + async for line in response: + print_verbose(f"line in async streaming: {line}") + yield line + except Exception as e: + custom_llm_provider = custom_llm_provider or "openai" + raise exception_type( + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + ) + + +def mock_completion( + model: str, + messages: List, + stream: Optional[bool] = False, + n: Optional[int] = None, + mock_response: Union[str, Exception, dict] = "This is a mock request", + mock_tool_calls: Optional[List] = None, + logging=None, + custom_llm_provider=None, + **kwargs, +): + """ + Generate a mock completion response for testing or debugging purposes. + + This is a helper function that simulates the response structure of the OpenAI completion API. + + Parameters: + model (str): The name of the language model for which the mock response is generated. + messages (List): A list of message objects representing the conversation context. + stream (bool, optional): If True, returns a mock streaming response (default is False). + mock_response (str, optional): The content of the mock response (default is "This is a mock request"). + **kwargs: Additional keyword arguments that can be used but are not required. + + Returns: + litellm.ModelResponse: A ModelResponse simulating a completion response with the specified model, messages, and mock response. + + Raises: + Exception: If an error occurs during the generation of the mock completion response. + + Note: + - This function is intended for testing or debugging purposes to generate mock completion responses. + - If 'stream' is True, it returns a response that mimics the behavior of a streaming completion. + """ + try: + ## LOGGING + if logging is not None: + logging.pre_call( + input=messages, + api_key="mock-key", + ) + if isinstance(mock_response, Exception): + if isinstance(mock_response, openai.APIError): + raise mock_response + raise litellm.MockException( + status_code=getattr(mock_response, "status_code", 500), # type: ignore + message=getattr(mock_response, "text", str(mock_response)), + llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore + model=model, # type: ignore + request=httpx.Request( + method="POST", url="https://api.openai.com/v1/" + ), + ) + elif ( + isinstance(mock_response, str) + and mock_response == "litellm.RateLimitError" + ): + raise litellm.RateLimitError( + message="this is a mock rate limit error", + llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore + model=model, + ) + elif isinstance(mock_response, str) and mock_response.startswith( + "Exception: content_filter_policy" + ): + raise litellm.MockException( + status_code=400, + message=mock_response, + llm_provider="azure", + model=model, # type: ignore + request=httpx.Request( + method="POST", url="https://api.openai.com/v1/" + ), + ) + elif isinstance(mock_response, str) and mock_response.startswith( + "Exception: mock_streaming_error" + ): + mock_response = litellm.MockException( + message="This is a mock error raised mid-stream", + llm_provider="anthropic", + model=model, + status_code=529, + ) + time_delay = kwargs.get("mock_delay", None) + if time_delay is not None: + time.sleep(time_delay) + + if isinstance(mock_response, dict): + return ModelResponse(**mock_response) + + model_response = ModelResponse(stream=stream) + if stream is True: + # don't try to access stream object, + if kwargs.get("acompletion", False) is True: + return CustomStreamWrapper( + completion_stream=async_mock_completion_streaming_obj( + model_response, + mock_response=mock_response, + model=model, + n=n, + ), + model=model, + custom_llm_provider="openai", + logging_obj=logging, + ) + return CustomStreamWrapper( + completion_stream=mock_completion_streaming_obj( + model_response, + mock_response=mock_response, + model=model, + n=n, + ), + model=model, + custom_llm_provider="openai", + logging_obj=logging, + ) + if isinstance(mock_response, litellm.MockException): + raise mock_response + if n is None: + model_response.choices[0].message.content = mock_response # type: ignore + else: + _all_choices = [] + for i in range(n): + _choice = litellm.utils.Choices( + index=i, + message=litellm.utils.Message( + content=mock_response, role="assistant" + ), + ) + _all_choices.append(_choice) + model_response.choices = _all_choices # type: ignore + model_response.created = int(time.time()) + model_response.model = model + + if mock_tool_calls: + model_response.choices[0].message.tool_calls = [ # type: ignore + ChatCompletionMessageToolCall(**tool_call) + for tool_call in mock_tool_calls + ] + + setattr( + model_response, + "usage", + Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30), + ) + + try: + _, custom_llm_provider, _, _ = litellm.utils.get_llm_provider( + model=model + ) + model_response._hidden_params[ + "custom_llm_provider" + ] = custom_llm_provider + except Exception: + # dont let setting a hidden param block a mock_respose + pass + + if logging is not None: + logging.post_call( + input=messages, + api_key="my-secret-key", + original_response="my-original-response", + ) + return model_response + + except Exception as e: + if isinstance(e, openai.APIError): + raise e + raise Exception("Mock completion response failed") + + +@client +def completion( + model: str, + # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create + messages: List = [], + timeout: Optional[Union[float, str, httpx.Timeout]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stream: Optional[bool] = None, + stream_options: Optional[dict] = None, + stop=None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[dict] = None, + user: Optional[str] = None, + # openai v1.0+ new params + response_format: Optional[Union[dict, Type[BaseModel]]] = None, + seed: Optional[int] = None, + tools: Optional[List] = None, + tool_choice: Optional[Union[str, dict]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + deployment_id=None, + extra_headers: Optional[dict] = None, + # soon to be deprecated params by OpenAI + functions: Optional[List] = None, + function_call: Optional[str] = None, + # set api_base, api_version, api_key + base_url: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. + # Optional liteLLM function params + **kwargs, +) -> Union[ModelResponse, CustomStreamWrapper]: + """ + Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) + Parameters: + model (str): The name of the language model to use for text completion. see all supported LLMs: https://docs.litellm.ai/docs/providers/ + messages (List): A list of message objects representing the conversation context (default is an empty list). + + OPTIONAL PARAMS + functions (List, optional): A list of functions to apply to the conversation messages (default is an empty list). + function_call (str, optional): The name of the function to call within the conversation (default is an empty string). + temperature (float, optional): The temperature parameter for controlling the randomness of the output (default is 1.0). + top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0). + n (int, optional): The number of completions to generate (default is 1). + stream (bool, optional): If True, return a streaming response (default is False). + stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true. + stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens. + max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity). + presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far. + frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far. + logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion. + user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. + logprobs (bool, optional): Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message + top_logprobs (int, optional): An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. + metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. + api_base (str, optional): Base URL for the API (default is None). + api_version (str, optional): API version (default is None). + api_key (str, optional): API key (default is None). + model_list (list, optional): List of api base, version, keys + extra_headers (dict, optional): Additional headers to include in the request. + + LITELLM Specific Params + mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None). + custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock" + max_retries (int, optional): The number of retries to attempt (default is 0). + Returns: + ModelResponse: A response object containing the generated completion and associated metadata. + + Note: + - This function is used to perform completions() using the specified language model. + - It supports various optional parameters for customizing the completion behavior. + - If 'mock_response' is provided, a mock completion response is returned for testing or debugging. + """ + ######### unpacking kwargs ##################### + args = locals() + api_base = kwargs.get("api_base", None) + mock_response = kwargs.get("mock_response", None) + mock_tool_calls = kwargs.get("mock_tool_calls", None) + force_timeout = kwargs.get("force_timeout", 600) ## deprecated + logger_fn = kwargs.get("logger_fn", None) + verbose = kwargs.get("verbose", False) + custom_llm_provider = kwargs.get("custom_llm_provider", None) + litellm_logging_obj = kwargs.get("litellm_logging_obj", None) + id = kwargs.get("id", None) + metadata = kwargs.get("metadata", None) + model_info = kwargs.get("model_info", None) + proxy_server_request = kwargs.get("proxy_server_request", None) + fallbacks = kwargs.get("fallbacks", None) + headers = kwargs.get("headers", None) or extra_headers + num_retries = kwargs.get( + "num_retries", None + ) ## alt. param for 'max_retries'. Use this to pass retries w/ instructor. + max_retries = kwargs.get("max_retries", None) + cooldown_time = kwargs.get("cooldown_time", None) + context_window_fallback_dict = kwargs.get( + "context_window_fallback_dict", None + ) + organization = kwargs.get("organization", None) + ### CUSTOM MODEL COST ### + input_cost_per_token = kwargs.get("input_cost_per_token", None) + output_cost_per_token = kwargs.get("output_cost_per_token", None) + input_cost_per_second = kwargs.get("input_cost_per_second", None) + output_cost_per_second = kwargs.get("output_cost_per_second", None) + ### CUSTOM PROMPT TEMPLATE ### + initial_prompt_value = kwargs.get("initial_prompt_value", None) + roles = kwargs.get("roles", None) + final_prompt_value = kwargs.get("final_prompt_value", None) + bos_token = kwargs.get("bos_token", None) + eos_token = kwargs.get("eos_token", None) + preset_cache_key = kwargs.get("preset_cache_key", None) + hf_model_name = kwargs.get("hf_model_name", None) + supports_system_message = kwargs.get("supports_system_message", None) + ### TEXT COMPLETION CALLS ### + text_completion = kwargs.get("text_completion", False) + atext_completion = kwargs.get("atext_completion", False) + ### ASYNC CALLS ### + acompletion = kwargs.get("acompletion", False) + client = kwargs.get("client", None) + ### Admin Controls ### + no_log = kwargs.get("no-log", False) + ### COPY MESSAGES ### - related issue https://github.com/BerriAI/litellm/discussions/4489 + messages = deepcopy(messages) + ######## end of unpacking kwargs ########### + openai_params = [ + "functions", + "function_call", + "temperature", + "temperature", + "top_p", + "n", + "stream", + "stream_options", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "request_timeout", + "api_base", + "api_version", + "api_key", + "deployment_id", + "organization", + "base_url", + "default_headers", + "timeout", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", + "parallel_tool_calls", + "logprobs", + "top_logprobs", + "extra_headers", + ] + litellm_params = all_litellm_params # use the external var., used in creating cache key as well. + + default_params = openai_params + litellm_params + non_default_params = { + k: v for k, v in kwargs.items() if k not in default_params + } # model-specific params - pass them straight to the model/provider + + try: + if base_url is not None: + api_base = base_url + if num_retries is not None: + max_retries = num_retries + logging = litellm_logging_obj + fallbacks = fallbacks or litellm.model_fallbacks + if fallbacks is not None: + return completion_with_fallbacks(**args) + if model_list is not None: + deployments = [ + m["litellm_params"] + for m in model_list + if m["model_name"] == model + ] + return batch_completion_models(deployments=deployments, **args) + if litellm.model_alias_map and model in litellm.model_alias_map: + model = litellm.model_alias_map[ + model + ] # update the model to the actual value if an alias has been passed in + model_response = ModelResponse() + setattr(model_response, "usage", litellm.Usage()) + if ( + kwargs.get("azure", False) == True + ): # don't remove flag check, to remain backwards compatible for repos like Codium + custom_llm_provider = "azure" + if deployment_id != None: # azure llms + model = deployment_id + custom_llm_provider = "azure" + ( + model, + custom_llm_provider, + dynamic_api_key, + api_base, + ) = get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + api_key=api_key, + ) + if model_response is not None and hasattr( + model_response, "_hidden_params" + ): + model_response._hidden_params[ + "custom_llm_provider" + ] = custom_llm_provider + model_response._hidden_params["region_name"] = kwargs.get( + "aws_region_name", None + ) # support region-based pricing for bedrock + + ### TIMEOUT LOGIC ### + timeout = timeout or kwargs.get("request_timeout", 600) or 600 + # set timeout for 10 minutes by default + if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout( + custom_llm_provider + ): + timeout = timeout.read or 600 # default 10 min timeout + elif not isinstance(timeout, httpx.Timeout): + timeout = float(timeout) # type: ignore + + ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### + if ( + input_cost_per_token is not None + and output_cost_per_token is not None + ): + litellm.register_model( + { + f"{custom_llm_provider}/{model}": { + "input_cost_per_token": input_cost_per_token, + "output_cost_per_token": output_cost_per_token, + "litellm_provider": custom_llm_provider, + }, + model: { + "input_cost_per_token": input_cost_per_token, + "output_cost_per_token": output_cost_per_token, + "litellm_provider": custom_llm_provider, + }, + } + ) + elif ( + input_cost_per_second is not None + ): # time based pricing just needs cost in place + output_cost_per_second = output_cost_per_second + litellm.register_model( + { + f"{custom_llm_provider}/{model}": { + "input_cost_per_second": input_cost_per_second, + "output_cost_per_second": output_cost_per_second, + "litellm_provider": custom_llm_provider, + }, + model: { + "input_cost_per_second": input_cost_per_second, + "output_cost_per_second": output_cost_per_second, + "litellm_provider": custom_llm_provider, + }, + } + ) + ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ### + custom_prompt_dict = {} # type: ignore + if ( + initial_prompt_value + or roles + or final_prompt_value + or bos_token + or eos_token + ): + custom_prompt_dict = {model: {}} + if initial_prompt_value: + custom_prompt_dict[model][ + "initial_prompt_value" + ] = initial_prompt_value + if roles: + custom_prompt_dict[model]["roles"] = roles + if final_prompt_value: + custom_prompt_dict[model][ + "final_prompt_value" + ] = final_prompt_value + if bos_token: + custom_prompt_dict[model]["bos_token"] = bos_token + if eos_token: + custom_prompt_dict[model]["eos_token"] = eos_token + + if ( + supports_system_message is not None + and isinstance(supports_system_message, bool) + and supports_system_message is False + ): + messages = map_system_message_pt(messages=messages) + model_api_key = get_api_key( + llm_provider=custom_llm_provider, dynamic_api_key=api_key + ) # get the api key from the environment if required for the model + + if dynamic_api_key is not None: + api_key = dynamic_api_key + # check if user passed in any of the OpenAI optional params + optional_params = get_optional_params( + functions=functions, + function_call=function_call, + temperature=temperature, + top_p=top_p, + n=n, + stream=stream, + stream_options=stream_options, + stop=stop, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + user=user, + # params to identify the model + model=model, + custom_llm_provider=custom_llm_provider, + response_format=response_format, + seed=seed, + tools=tools, + tool_choice=tool_choice, + max_retries=max_retries, + logprobs=logprobs, + top_logprobs=top_logprobs, + extra_headers=extra_headers, + api_version=api_version, + parallel_tool_calls=parallel_tool_calls, + **non_default_params, + ) + + if litellm.add_function_to_prompt and optional_params.get( + "functions_unsupported_model", None + ): # if user opts to add it to prompt, when API doesn't support function calling + functions_unsupported_model = optional_params.pop( + "functions_unsupported_model" + ) + messages = function_call_prompt( + messages=messages, functions=functions_unsupported_model + ) + + # For logging - save the values of the litellm-specific params passed in + litellm_params = get_litellm_params( + acompletion=acompletion, + api_key=api_key, + force_timeout=force_timeout, + logger_fn=logger_fn, + verbose=verbose, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + litellm_call_id=kwargs.get("litellm_call_id", None), + model_alias_map=litellm.model_alias_map, + completion_call_id=id, + metadata=metadata, + model_info=model_info, + proxy_server_request=proxy_server_request, + preset_cache_key=preset_cache_key, + no_log=no_log, + input_cost_per_second=input_cost_per_second, + input_cost_per_token=input_cost_per_token, + output_cost_per_second=output_cost_per_second, + output_cost_per_token=output_cost_per_token, + cooldown_time=cooldown_time, + text_completion=kwargs.get("text_completion"), + azure_ad_token_provider=kwargs.get("azure_ad_token_provider"), + user_continue_message=kwargs.get("user_continue_message"), + ) + logging.update_environment_variables( + model=model, + user=user, + optional_params=optional_params, + litellm_params=litellm_params, + custom_llm_provider=custom_llm_provider, + ) + if mock_response or mock_tool_calls: + return mock_completion( + model, + messages, + stream=stream, + n=n, + mock_response=mock_response, + mock_tool_calls=mock_tool_calls, + logging=logging, + acompletion=acompletion, + mock_delay=kwargs.get("mock_delay", None), + custom_llm_provider=custom_llm_provider, + ) + + if custom_llm_provider == "azure": + # azure configs + ## check dynamic params ## + dynamic_params = False + if client is not None and ( + isinstance(client, openai.AzureOpenAI) + or isinstance(client, openai.AsyncAzureOpenAI) + ): + dynamic_params = _check_dynamic_azure_params( + azure_client_params={"api_version": api_version}, + azure_client=client, + ) + + api_type = get_secret("AZURE_API_TYPE") or "azure" + + api_base = ( + api_base or litellm.api_base or get_secret("AZURE_API_BASE") + ) + + api_version = ( + api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + or litellm.AZURE_DEFAULT_API_VERSION + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) + + azure_ad_token = optional_params.get("extra_body", {}).pop( + "azure_ad_token", None + ) or get_secret("AZURE_AD_TOKEN") + + headers = headers or litellm.headers + + ## LOAD CONFIG - if set + config = litellm.AzureOpenAIConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## COMPLETION CALL + response = azure_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + api_key=api_key, + api_base=api_base, + api_version=api_version, + api_type=api_type, + dynamic_params=dynamic_params, + azure_ad_token=azure_ad_token, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, # type: ignore + client=client, # pass AsyncAzureOpenAI, AzureOpenAI client + ) + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + additional_args={ + "headers": headers, + "api_version": api_version, + "api_base": api_base, + }, + ) + + elif custom_llm_provider == "notdiamond": + notdiamond_key = ( + api_key + or litellm.notdiamond_key + or get_secret("NOTDIAMOND_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("NOTDIAMOND_API_BASE") + or "https://not-diamond-server.onrender.com/v2/optimizer/modelSelect" + ) + + # since notdiamond.completion() internally calls other models' completion functions + # streaming does not need to be handled separately + response = notdiamond_completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=notdiamond_key, + logging_obj=logging, + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + ) + + response = response + + elif custom_llm_provider == "azure_text": + # azure configs + api_type = get_secret("AZURE_API_TYPE") or "azure" + + api_base = ( + api_base or litellm.api_base or get_secret("AZURE_API_BASE") + ) + + api_version = ( + api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) + + azure_ad_token = optional_params.get("extra_body", {}).pop( + "azure_ad_token", None + ) or get_secret("AZURE_AD_TOKEN") + + headers = headers or litellm.headers + + ## LOAD CONFIG - if set + config = litellm.AzureOpenAIConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## COMPLETION CALL + response = azure_text_completions.completion( + model=model, + messages=messages, + headers=headers, + api_key=api_key, + api_base=api_base, + api_version=api_version, + api_type=api_type, + azure_ad_token=azure_ad_token, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, + client=client, # pass AsyncAzureOpenAI, AzureOpenAI client + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + additional_args={ + "headers": headers, + "api_version": api_version, + "api_base": api_base, + }, + ) + elif custom_llm_provider == "azure_ai": + api_base = ( + api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or get_secret("AZURE_AI_API_BASE") + ) + # set API KEY + api_key = ( + api_key + or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or get_secret("AZURE_AI_API_KEY") + ) + + headers = headers or litellm.headers + + ## LOAD CONFIG - if set + config = litellm.OpenAIConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## FOR COHERE + if "command-r" in model: # make sure tool call in messages are str + messages = stringify_json_tool_call_content(messages=messages) + + ## COMPLETION CALL + try: + response = openai_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + model_response=model_response, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, # type: ignore + custom_prompt_dict=custom_prompt_dict, + client=client, # pass AsyncOpenAI, OpenAI client + organization=organization, + custom_llm_provider=custom_llm_provider, + drop_params=non_default_params.get("drop_params"), + ) + except Exception as e: + ## LOGGING - log the original exception returned + logging.post_call( + input=messages, + api_key=api_key, + original_response=str(e), + additional_args={"headers": headers}, + ) + raise e + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + additional_args={"headers": headers}, + ) + elif ( + custom_llm_provider == "text-completion-openai" + or "ft:babbage-002" in model + or "ft:davinci-002" + in model # support for finetuned completion models + or custom_llm_provider + in litellm.openai_text_completion_compatible_providers + and kwargs.get("text_completion") is True + ): + openai.api_type = "openai" + + api_base = ( + api_base + or litellm.api_base + or get_secret("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + + openai.api_version = None + # set API KEY + + api_key = ( + api_key + or litellm.api_key + or litellm.openai_key + or get_secret("OPENAI_API_KEY") + ) + + headers = headers or litellm.headers + + ## LOAD CONFIG - if set + config = litellm.OpenAITextCompletionConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + if litellm.organization: + openai.organization = litellm.organization + + if ( + len(messages) > 0 + and "content" in messages[0] + and type(messages[0]["content"]) == list + ): + # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content'] + # https://platform.openai.com/docs/api-reference/completions/create + prompt = messages[0]["content"] + else: + prompt = " ".join([message["content"] for message in messages]) # type: ignore + + ## COMPLETION CALL + _response = openai_text_completions.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + client=client, # pass AsyncOpenAI, OpenAI client + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, # type: ignore + ) + + if ( + optional_params.get("stream", False) == False + and acompletion == False + and text_completion == False + ): + # convert to chat completion response + _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object( + response_object=_response, + model_response_object=model_response, + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=_response, + additional_args={"headers": headers}, + ) + response = _response + + elif ( + model in litellm.open_ai_chat_completion_models + or custom_llm_provider == "custom_openai" + or custom_llm_provider == "deepinfra" + or custom_llm_provider == "perplexity" + or custom_llm_provider == "groq" + or custom_llm_provider == "nvidia_nim" + or custom_llm_provider == "cerebras" + or custom_llm_provider == "ai21_chat" + or custom_llm_provider == "volcengine" + or custom_llm_provider == "codestral" + or custom_llm_provider == "deepseek" + or custom_llm_provider == "anyscale" + or custom_llm_provider == "mistral" + or custom_llm_provider == "openai" + or custom_llm_provider == "together_ai" + or custom_llm_provider in litellm.openai_compatible_providers + or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo + ): # allow user to make an openai call with a custom base + # note: if a user sets a custom base - we should ensure this works + # allow for the setting of dynamic and stateful api-bases + api_base = ( + api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or get_secret("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + openai.organization = ( + organization + or litellm.organization + or get_secret("OPENAI_ORGANIZATION") + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + api_key + or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or get_secret("OPENAI_API_KEY") + ) + + headers = headers or litellm.headers + + ## LOAD CONFIG - if set + config = litellm.OpenAIConfig.get_config() + for k, v in config.items(): + if ( + k not in optional_params + ): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in + optional_params[k] = v + + ## COMPLETION CALL + try: + response = openai_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + model_response=model_response, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, # type: ignore + custom_prompt_dict=custom_prompt_dict, + client=client, # pass AsyncOpenAI, OpenAI client + organization=organization, + custom_llm_provider=custom_llm_provider, + ) + except Exception as e: + ## LOGGING - log the original exception returned + logging.post_call( + input=messages, + api_key=api_key, + original_response=str(e), + additional_args={"headers": headers}, + ) + raise e + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + additional_args={"headers": headers}, + ) + elif ( + "replicate" in model + or custom_llm_provider == "replicate" + or model in litellm.replicate_models + ): + # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN") + replicate_key = None + replicate_key = ( + api_key + or litellm.replicate_key + or litellm.api_key + or get_secret("REPLICATE_API_KEY") + or get_secret("REPLICATE_API_TOKEN") + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("REPLICATE_API_BASE") + or "https://api.replicate.com/v1" + ) + + custom_prompt_dict = ( + custom_prompt_dict or litellm.custom_prompt_dict + ) + + model_response = replicate.completion( # type: ignore + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=replicate_key, + logging_obj=logging, + custom_prompt_dict=custom_prompt_dict, + acompletion=acompletion, + ) + + if optional_params.get("stream", False) == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=replicate_key, + original_response=model_response, + ) + + response = model_response + elif ( + "clarifai" in model + or custom_llm_provider == "clarifai" + or model in litellm.clarifai_models + ): + clarifai_key = None + clarifai_key = ( + api_key + or litellm.clarifai_key + or litellm.api_key + or get_secret("CLARIFAI_API_KEY") + or get_secret("CLARIFAI_API_TOKEN") + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("CLARIFAI_API_BASE") + or "https://api.clarifai.com/v2" + ) + + custom_prompt_dict = ( + custom_prompt_dict or litellm.custom_prompt_dict + ) + model_response = clarifai.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + acompletion=acompletion, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=clarifai_key, + logging_obj=logging, + custom_prompt_dict=custom_prompt_dict, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + ): + # don't try to access stream object, + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=model_response, + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=clarifai_key, + original_response=model_response, + ) + response = model_response + + elif custom_llm_provider == "anthropic": + api_key = ( + api_key + or litellm.anthropic_key + or litellm.api_key + or os.environ.get("ANTHROPIC_API_KEY") + ) + custom_prompt_dict = ( + custom_prompt_dict or litellm.custom_prompt_dict + ) + + if (model == "claude-2") or (model == "claude-instant-1"): + # call anthropic /completion, only use this route for claude-2, claude-instant-1 + api_base = ( + api_base + or litellm.api_base + or get_secret("ANTHROPIC_API_BASE") + or get_secret("ANTHROPIC_BASE_URL") + or "https://api.anthropic.com/v1/complete" + ) + + if api_base is not None and not api_base.endswith( + "/v1/complete" + ): + api_base += "/v1/complete" + + response = anthropic_text_completions.completion( + model=model, + messages=messages, + api_base=api_base, + acompletion=acompletion, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=api_key, + logging_obj=logging, + headers=headers, + ) + else: + # call /messages + # default route for all anthropic models + api_base = ( + api_base + or litellm.api_base + or get_secret("ANTHROPIC_API_BASE") + or get_secret("ANTHROPIC_BASE_URL") + or "https://api.anthropic.com/v1/messages" + ) + + if api_base is not None and not api_base.endswith( + "/v1/messages" + ): + api_base += "/v1/messages" + + response = anthropic_chat_completions.completion( + model=model, + messages=messages, + api_base=api_base, + acompletion=acompletion, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=api_key, + logging_obj=logging, + headers=headers, + timeout=timeout, + client=client, + ) + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + ) + response = response + elif custom_llm_provider == "nlp_cloud": + nlp_cloud_key = ( + api_key + or litellm.nlp_cloud_key + or get_secret("NLP_CLOUD_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("NLP_CLOUD_API_BASE") + or "https://api.nlpcloud.io/v1/gpu/" + ) + + response = nlp_cloud.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=nlp_cloud_key, + logging_obj=logging, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + ): + # don't try to access stream object, + response = CustomStreamWrapper( + response, + model, + custom_llm_provider="nlp_cloud", + logging_obj=logging, + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + ) + + response = response + elif custom_llm_provider == "aleph_alpha": + aleph_alpha_key = ( + api_key + or litellm.aleph_alpha_key + or get_secret("ALEPH_ALPHA_API_KEY") + or get_secret("ALEPHALPHA_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("ALEPH_ALPHA_API_BASE") + or "https://api.aleph-alpha.com/complete" + ) + + model_response = aleph_alpha.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + default_max_tokens_to_sample=litellm.max_tokens, + api_key=aleph_alpha_key, + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + ): + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="aleph_alpha", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "cohere": + cohere_key = ( + api_key + or litellm.cohere_key + or get_secret("COHERE_API_KEY") + or get_secret("CO_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("COHERE_API_BASE") + or "https://api.cohere.ai/v1/generate" + ) + + headers = headers or litellm.headers or {} + if headers is None: + headers = {} + + if extra_headers is not None: + headers.update(extra_headers) + + model_response = cohere_completion.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + headers=headers, + api_key=cohere_key, + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + ): + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="cohere", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "cohere_chat": + cohere_key = ( + api_key + or litellm.cohere_key + or get_secret("COHERE_API_KEY") + or get_secret("CO_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("COHERE_API_BASE") + or "https://api.cohere.ai/v1/chat" + ) + + headers = headers or litellm.headers or {} + if headers is None: + headers = {} + + if extra_headers is not None: + headers.update(extra_headers) + + model_response = cohere_chat.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + headers=headers, + logger_fn=logger_fn, + encoding=encoding, + api_key=cohere_key, + logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + ): + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="cohere_chat", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "maritalk": + maritalk_key = ( + api_key + or litellm.maritalk_key + or get_secret("MARITALK_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("MARITALK_API_BASE") + or "https://chat.maritaca.ai/api/chat/inference" + ) + + model_response = maritalk.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=maritalk_key, + logging_obj=logging, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + ): + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="maritalk", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "huggingface": + custom_llm_provider = "huggingface" + huggingface_key = ( + api_key + or litellm.huggingface_key + or os.environ.get("HF_TOKEN") + or os.environ.get("HUGGINGFACE_API_KEY") + or litellm.api_key + ) + hf_headers = headers or litellm.headers + + custom_prompt_dict = ( + custom_prompt_dict or litellm.custom_prompt_dict + ) + model_response = huggingface.completion( + model=model, + messages=messages, + api_base=api_base, # type: ignore + headers=hf_headers, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=huggingface_key, + acompletion=acompletion, + logging_obj=logging, + custom_prompt_dict=custom_prompt_dict, + timeout=timeout, # type: ignore + ) + if ( + "stream" in optional_params + and optional_params["stream"] == True + and acompletion is False + ): + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="huggingface", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "oobabooga": + custom_llm_provider = "oobabooga" + model_response = oobabooga.completion( + model=model, + messages=messages, + model_response=model_response, + api_base=api_base, # type: ignore + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + api_key=None, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + ) + if ( + "stream" in optional_params + and optional_params["stream"] == True + ): + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="oobabooga", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "databricks": + api_base = ( + api_base # for databricks we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or os.getenv("DATABRICKS_API_BASE") + ) + + # set API KEY + api_key = ( + api_key + or litellm.api_key # for databricks we check in get_llm_provider and pass in the api key from there + or litellm.databricks_key + or get_secret("DATABRICKS_API_KEY") + ) + + headers = headers or litellm.headers + + ## COMPLETION CALL + try: + response = databricks_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + model_response=model_response, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, # type: ignore + custom_prompt_dict=custom_prompt_dict, + client=client, # pass AsyncOpenAI, OpenAI client + encoding=encoding, + custom_llm_provider="databricks", + ) + except Exception as e: + ## LOGGING - log the original exception returned + logging.post_call( + input=messages, + api_key=api_key, + original_response=str(e), + additional_args={"headers": headers}, + ) + raise e + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + additional_args={"headers": headers}, + ) + elif custom_llm_provider == "openrouter": + api_base = ( + api_base or litellm.api_base or "https://openrouter.ai/api/v1" + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.openrouter_key + or get_secret("OPENROUTER_API_KEY") + or get_secret("OR_API_KEY") + ) + + openrouter_site_url = ( + get_secret("OR_SITE_URL") or "https://litellm.ai" + ) + openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM" + + openrouter_headers = { + "HTTP-Referer": openrouter_site_url, + "X-Title": openrouter_app_name, + } + + _headers = headers or litellm.headers + if _headers: + openrouter_headers.update(_headers) + + headers = openrouter_headers + + ## Load Config + config = openrouter.OpenrouterConfig.get_config() + for k, v in config.items(): + if k == "extra_body": + # we use openai 'extra_body' to pass openrouter specific params - transforms, route, models + if "extra_body" in optional_params: + optional_params[k].update(v) + else: + optional_params[k] = v + elif k not in optional_params: + optional_params[k] = v + + data = {"model": model, "messages": messages, **optional_params} + + ## COMPLETION CALL + response = openai_chat_completions.completion( + model=model, + messages=messages, + headers=headers, + api_key=api_key, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, # type: ignore + custom_llm_provider="openrouter", + ) + ## LOGGING + logging.post_call( + input=messages, + api_key=openai.api_key, + original_response=response, + ) + elif ( + custom_llm_provider == "together_ai" + or ("togethercomputer" in model) + or (model in litellm.together_ai_models) + ): + """ + Deprecated. We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility + """ + pass + elif custom_llm_provider == "palm": + palm_api_key = ( + api_key or get_secret("PALM_API_KEY") or litellm.api_key + ) + + # palm does not support streaming as yet :( + model_response = palm.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=palm_api_key, + logging_obj=logging, + ) + # fake palm streaming + if ( + "stream" in optional_params + and optional_params["stream"] == True + ): + # fake streaming for palm + resp_string = model_response["choices"][0]["message"][ + "content" + ] + response = CustomStreamWrapper( + resp_string, + model, + custom_llm_provider="palm", + logging_obj=logging, + ) + return response + response = model_response + elif ( + custom_llm_provider == "vertex_ai_beta" + or custom_llm_provider == "gemini" + ): + vertex_ai_project = ( + optional_params.pop("vertex_project", None) + or optional_params.pop("vertex_ai_project", None) + or litellm.vertex_project + or get_secret("VERTEXAI_PROJECT") + ) + vertex_ai_location = ( + optional_params.pop("vertex_location", None) + or optional_params.pop("vertex_ai_location", None) + or litellm.vertex_location + or get_secret("VERTEXAI_LOCATION") + ) + vertex_credentials = ( + optional_params.pop("vertex_credentials", None) + or optional_params.pop("vertex_ai_credentials", None) + or get_secret("VERTEXAI_CREDENTIALS") + ) + + gemini_api_key = ( + api_key + or get_secret("GEMINI_API_KEY") + or get_secret( + "PALM_API_KEY" + ) # older palm api key should also work + or litellm.api_key + ) + + new_params = deepcopy(optional_params) + response = vertex_chat_completion.completion( # type: ignore + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, # type: ignore + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + gemini_api_key=gemini_api_key, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, + custom_llm_provider=custom_llm_provider, + client=client, + api_base=api_base, + extra_headers=extra_headers, + ) + + elif custom_llm_provider == "vertex_ai": + vertex_ai_project = ( + optional_params.pop("vertex_project", None) + or optional_params.pop("vertex_ai_project", None) + or litellm.vertex_project + or get_secret("VERTEXAI_PROJECT") + ) + vertex_ai_location = ( + optional_params.pop("vertex_location", None) + or optional_params.pop("vertex_ai_location", None) + or litellm.vertex_location + or get_secret("VERTEXAI_LOCATION") + ) + vertex_credentials = ( + optional_params.pop("vertex_credentials", None) + or optional_params.pop("vertex_ai_credentials", None) + or get_secret("VERTEXAI_CREDENTIALS") + ) + + new_params = deepcopy(optional_params) + if "claude-3" in model: + model_response = vertex_ai_anthropic.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + logging_obj=logging, + acompletion=acompletion, + headers=headers, + custom_prompt_dict=custom_prompt_dict, + timeout=timeout, + client=client, + ) + elif ( + model.startswith("meta/") + or model.startswith("mistral") + or model.startswith("codestral") + or model.startswith("jamba") + ): + model_response = ( + vertex_partner_models_chat_completion.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, # type: ignore + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + logging_obj=logging, + acompletion=acompletion, + headers=headers, + custom_prompt_dict=custom_prompt_dict, + timeout=timeout, + client=client, + ) + ) + elif "gemini" in model: + model_response = vertex_chat_completion.completion( # type: ignore + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, # type: ignore + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + gemini_api_key=None, + logging_obj=logging, + acompletion=acompletion, + timeout=timeout, + custom_llm_provider=custom_llm_provider, + client=client, + api_base=api_base, + extra_headers=extra_headers, + ) + else: + model_response = vertex_ai_non_gemini.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + logging_obj=logging, + acompletion=acompletion, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] is True + and acompletion is False + ): + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="vertex_ai", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "predibase": + tenant_id = ( + optional_params.pop("tenant_id", None) + or optional_params.pop("predibase_tenant_id", None) + or litellm.predibase_tenant_id + or get_secret("PREDIBASE_TENANT_ID") + ) + + api_base = ( + api_base + or optional_params.pop("api_base", None) + or optional_params.pop("base_url", None) + or litellm.api_base + or get_secret("PREDIBASE_API_BASE") + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.predibase_key + or get_secret("PREDIBASE_API_KEY") + ) + + _model_response = predibase_chat_completions.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + acompletion=acompletion, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + api_key=api_key, + tenant_id=tenant_id, + timeout=timeout, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] is True + and acompletion is False + ): + return _model_response + response = _model_response + elif custom_llm_provider == "text-completion-codestral": + api_base = ( + api_base + or optional_params.pop("api_base", None) + or optional_params.pop("base_url", None) + or litellm.api_base + or "https://codestral.mistral.ai/v1/fim/completions" + ) + + api_key = ( + api_key or litellm.api_key or get_secret("CODESTRAL_API_KEY") + ) + + text_completion_model_response = litellm.TextCompletionResponse( + stream=stream + ) + + _model_response = codestral_text_completions.completion( # type: ignore + model=model, + messages=messages, + model_response=text_completion_model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + acompletion=acompletion, + api_base=api_base, + custom_prompt_dict=custom_prompt_dict, + api_key=api_key, + timeout=timeout, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] is True + and acompletion is False + ): + return _model_response + response = _model_response + elif custom_llm_provider == "ai21": + custom_llm_provider = "ai21" + ai21_key = ( + api_key + or litellm.ai21_key + or os.environ.get("AI21_API_KEY") + or litellm.api_key + ) + + api_base = ( + api_base + or litellm.api_base + or get_secret("AI21_API_BASE") + or "https://api.ai21.com/studio/v1/" + ) + + model_response = ai21.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=ai21_key, + logging_obj=logging, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + ): + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="ai21", + logging_obj=logging, + ) + return response + + ## RESPONSE OBJECT + response = model_response + elif ( + custom_llm_provider == "sagemaker" + or custom_llm_provider == "sagemaker_chat" + ): + # boto3 reads keys from .env + model_response = sagemaker_llm.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + custom_prompt_dict=custom_prompt_dict, + hf_model_name=hf_model_name, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + acompletion=acompletion, + use_messages_api=( + True if custom_llm_provider == "sagemaker_chat" else False + ), + ) + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=None, + original_response=model_response, + ) + + ## RESPONSE OBJECT + response = model_response + elif custom_llm_provider == "bedrock": + # boto3 reads keys from .env + custom_prompt_dict = ( + custom_prompt_dict or litellm.custom_prompt_dict + ) + + if "aws_bedrock_client" in optional_params: + verbose_logger.warning( + "'aws_bedrock_client' is a deprecated param. Please move to another auth method - https://docs.litellm.ai/docs/providers/bedrock#boto3---authentication." + ) + # Extract credentials for legacy boto3 client and pass thru to httpx + aws_bedrock_client = optional_params.pop("aws_bedrock_client") + creds = ( + aws_bedrock_client._get_credentials().get_frozen_credentials() + ) + + if creds.access_key: + optional_params["aws_access_key_id"] = creds.access_key + if creds.secret_key: + optional_params["aws_secret_access_key"] = creds.secret_key + if creds.token: + optional_params["aws_session_token"] = creds.token + if ( + "aws_region_name" not in optional_params + or optional_params["aws_region_name"] is None + ): + optional_params[ + "aws_region_name" + ] = aws_bedrock_client.meta.region_name + + if model in litellm.BEDROCK_CONVERSE_MODELS: + response = bedrock_converse_chat_completion.completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, # type: ignore + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + extra_headers=extra_headers, + timeout=timeout, + acompletion=acompletion, + client=client, + api_base=api_base, + ) + else: + response = bedrock_chat_completion.completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + extra_headers=extra_headers, + timeout=timeout, + acompletion=acompletion, + client=client, + api_base=api_base, + ) + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=None, + original_response=response, + ) + + ## RESPONSE OBJECT + response = response + elif custom_llm_provider == "watsonx": + custom_prompt_dict = ( + custom_prompt_dict or litellm.custom_prompt_dict + ) + response = watsonxai.completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, # type: ignore + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + timeout=timeout, # type: ignore + acompletion=acompletion, + ) + if ( + "stream" in optional_params + and optional_params["stream"] == True + and not isinstance(response, CustomStreamWrapper) + ): + # don't try to access stream object, + response = CustomStreamWrapper( + iter(response), + model, + custom_llm_provider="watsonx", + logging_obj=logging, + ) + + if optional_params.get("stream", False): + ## LOGGING + logging.post_call( + input=messages, + api_key=None, + original_response=response, + ) + ## RESPONSE OBJECT + response = response + elif custom_llm_provider == "vllm": + custom_prompt_dict = ( + custom_prompt_dict or litellm.custom_prompt_dict + ) + model_response = vllm.completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + ) + + if ( + "stream" in optional_params + and optional_params["stream"] == True + ): ## [BETA] + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="vllm", + logging_obj=logging, + ) + return response + + ## RESPONSE OBJECT + response = model_response + elif custom_llm_provider == "ollama": + api_base = ( + litellm.api_base + or api_base + or get_secret("OLLAMA_API_BASE") + or "http://localhost:11434" + ) + custom_prompt_dict = ( + custom_prompt_dict or litellm.custom_prompt_dict + ) + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details["roles"], + initial_prompt_value=model_prompt_details[ + "initial_prompt_value" + ], + final_prompt_value=model_prompt_details[ + "final_prompt_value" + ], + messages=messages, + ) + else: + prompt = prompt_factory( + model=model, + messages=messages, + custom_llm_provider=custom_llm_provider, + ) + if isinstance(prompt, dict): + # for multimode models - ollama/llava prompt_factory returns a dict { + # "prompt": prompt, + # "images": images + # } + prompt, images = prompt["prompt"], prompt["images"] + optional_params["images"] = images + + ## LOGGING + generator = ollama.get_ollama_response( + api_base=api_base, + model=model, + prompt=prompt, + optional_params=optional_params, + logging_obj=logging, + acompletion=acompletion, + model_response=model_response, + encoding=encoding, + ) + if ( + acompletion is True + or optional_params.get("stream", False) == True + ): + return generator + + response = generator + elif custom_llm_provider == "ollama_chat": + api_base = ( + litellm.api_base + or api_base + or get_secret("OLLAMA_API_BASE") + or "http://localhost:11434" + ) + + api_key = ( + api_key + or litellm.ollama_key + or os.environ.get("OLLAMA_API_KEY") + or litellm.api_key + ) + ## LOGGING + generator = ollama_chat.get_ollama_response( + api_base=api_base, + api_key=api_key, + model=model, + messages=messages, + optional_params=optional_params, + logging_obj=logging, + acompletion=acompletion, + model_response=model_response, + encoding=encoding, + ) + if ( + acompletion is True + or optional_params.get("stream", False) is True + ): + return generator + + response = generator + + elif custom_llm_provider == "triton": + api_base = litellm.api_base or api_base + model_response = triton_chat_completions.completion( + api_base=api_base, + timeout=timeout, # type: ignore + model=model, + messages=messages, + model_response=model_response, + optional_params=optional_params, + logging_obj=logging, + stream=stream, + acompletion=acompletion, + ) + + ## RESPONSE OBJECT + response = model_response + return response + + elif custom_llm_provider == "cloudflare": + api_key = ( + api_key + or litellm.cloudflare_api_key + or litellm.api_key + or get_secret("CLOUDFLARE_API_KEY") + ) + account_id = get_secret("CLOUDFLARE_ACCOUNT_ID") + api_base = ( + api_base + or litellm.api_base + or get_secret("CLOUDFLARE_API_BASE") + or f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/" + ) + + custom_prompt_dict = ( + custom_prompt_dict or litellm.custom_prompt_dict + ) + response = cloudflare.completion( + model=model, + messages=messages, + api_base=api_base, + custom_prompt_dict=litellm.custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, # for calculating input/output tokens + api_key=api_key, + logging_obj=logging, + ) + if ( + "stream" in optional_params + and optional_params["stream"] == True + ): + # don't try to access stream object, + response = CustomStreamWrapper( + response, + model, + custom_llm_provider="cloudflare", + logging_obj=logging, + ) + + if optional_params.get("stream", False) or acompletion == True: + ## LOGGING + logging.post_call( + input=messages, + api_key=api_key, + original_response=response, + ) + response = response + elif ( + custom_llm_provider == "baseten" + or litellm.api_base == "https://app.baseten.co" + ): + custom_llm_provider = "baseten" + baseten_key = ( + api_key + or litellm.baseten_key + or os.environ.get("BASETEN_API_KEY") + or litellm.api_key + ) + + model_response = baseten.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + api_key=baseten_key, + logging_obj=logging, + ) + if inspect.isgenerator(model_response) or ( + "stream" in optional_params + and optional_params["stream"] == True + ): + # don't try to access stream object, + response = CustomStreamWrapper( + model_response, + model, + custom_llm_provider="baseten", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "petals" or model in litellm.petals_models: + api_base = api_base or litellm.api_base + + custom_llm_provider = "petals" + stream = optional_params.pop("stream", False) + model_response = petals.completion( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + ) + if stream == True: ## [BETA] + # Fake streaming for petals + resp_string = model_response["choices"][0]["message"][ + "content" + ] + response = CustomStreamWrapper( + resp_string, + model, + custom_llm_provider="petals", + logging_obj=logging, + ) + return response + response = model_response + elif custom_llm_provider == "custom": + import requests + + url = litellm.api_base or api_base or "" + if url == None or url == "": + raise ValueError( + "api_base not set. Set api_base or litellm.api_base for custom endpoints" + ) + + """ + assume input to custom LLM api bases follow this format: + resp = requests.post( + api_base, + json={ + 'model': 'meta-llama/Llama-2-13b-hf', # model name + 'params': { + 'prompt': ["The capital of France is P"], + 'max_tokens': 32, + 'temperature': 0.7, + 'top_p': 1.0, + 'top_k': 40, + } + } + ) + + """ + prompt = " ".join([message["content"] for message in messages]) # type: ignore + resp = requests.post( + url, + json={ + "model": model, + "params": { + "prompt": [prompt], + "max_tokens": max_tokens, + "temperature": temperature, + "top_p": top_p, + "top_k": kwargs.get("top_k", 40), + }, + }, + verify=litellm.ssl_verify, + ) + response_json = resp.json() + """ + assume all responses from custom api_bases of this format: + { + 'data': [ + { + 'prompt': 'The capital of France is P', + 'output': ['The capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France is PARIS.\nThe capital of France'], + 'params': {'temperature': 0.7, 'top_k': 40, 'top_p': 1}}], + 'message': 'ok' + } + ] + } + """ + string_response = response_json["data"][0]["output"][0] + ## RESPONSE OBJECT + model_response.choices[0].message.content = string_response # type: ignore + model_response.created = int(time.time()) + model_response.model = model + response = model_response + elif ( + custom_llm_provider in litellm._custom_providers + ): # Assume custom LLM provider + # Get the Custom Handler + custom_handler: Optional[CustomLLM] = None + for item in litellm.custom_provider_map: + if item["provider"] == custom_llm_provider: + custom_handler = item["custom_handler"] + + if custom_handler is None: + raise ValueError( + f"Unable to map your input to a model. Check your input - {args}" + ) + + ## ROUTE LLM CALL ## + handler_fn = custom_chat_llm_router( + async_fn=acompletion, stream=stream, custom_llm=custom_handler + ) + + headers = headers or litellm.headers + + ## CALL FUNCTION + response = handler_fn( + model=model, + messages=messages, + headers=headers, + model_response=model_response, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + acompletion=acompletion, + logging_obj=logging, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, # type: ignore + custom_prompt_dict=custom_prompt_dict, + client=client, # pass AsyncOpenAI, OpenAI client + encoding=encoding, + ) + if stream is True: + return CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider=custom_llm_provider, + logging_obj=logging, + ) + + else: + raise ValueError( + f"Unable to map your input to a model. Check your input - {args}" + ) + return response + except Exception as e: + ## Map to OpenAI Exception + raise exception_type( + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + extra_kwargs=kwargs, + ) + + +def completion_with_retries(*args, **kwargs): + """ + Executes a litellm.completion() with 3 retries + """ + try: + import tenacity + except Exception as e: + raise Exception( + f"tenacity import failed please run `pip install tenacity`. Error{e}" + ) + + num_retries = kwargs.pop("num_retries", 3) + retry_strategy = kwargs.pop("retry_strategy", "constant_retry") + original_function = kwargs.pop("original_function", completion) + if retry_strategy == "constant_retry": + retryer = tenacity.Retrying( + stop=tenacity.stop_after_attempt(num_retries), reraise=True + ) + elif retry_strategy == "exponential_backoff_retry": + retryer = tenacity.Retrying( + wait=tenacity.wait_exponential(multiplier=1, max=10), + stop=tenacity.stop_after_attempt(num_retries), + reraise=True, + ) + return retryer(original_function, *args, **kwargs) + + +async def acompletion_with_retries(*args, **kwargs): + """ + [DEPRECATED]. Use 'acompletion' or router.acompletion instead! + Executes a litellm.completion() with 3 retries + """ + try: + import tenacity + except Exception as e: + raise Exception( + f"tenacity import failed please run `pip install tenacity`. Error{e}" + ) + + num_retries = kwargs.pop("num_retries", 3) + retry_strategy = kwargs.pop("retry_strategy", "constant_retry") + original_function = kwargs.pop("original_function", completion) + if retry_strategy == "constant_retry": + retryer = tenacity.Retrying( + stop=tenacity.stop_after_attempt(num_retries), reraise=True + ) + elif retry_strategy == "exponential_backoff_retry": + retryer = tenacity.Retrying( + wait=tenacity.wait_exponential(multiplier=1, max=10), + stop=tenacity.stop_after_attempt(num_retries), + reraise=True, + ) + return await retryer(original_function, *args, **kwargs) + + +def batch_completion( + model: str, + # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create + messages: List = [], + functions: Optional[List] = None, + function_call: Optional[str] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stream: Optional[bool] = None, + stop=None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[dict] = None, + user: Optional[str] = None, + deployment_id=None, + request_timeout: Optional[int] = None, + timeout: Optional[int] = 600, + # Optional liteLLM function params + **kwargs, +): + """ + Batch litellm.completion function for a given model. + + Args: + model (str): The model to use for generating completions. + messages (List, optional): List of messages to use as input for generating completions. Defaults to []. + functions (List, optional): List of functions to use as input for generating completions. Defaults to []. + function_call (str, optional): The function call to use as input for generating completions. Defaults to "". + temperature (float, optional): The temperature parameter for generating completions. Defaults to None. + top_p (float, optional): The top-p parameter for generating completions. Defaults to None. + n (int, optional): The number of completions to generate. Defaults to None. + stream (bool, optional): Whether to stream completions or not. Defaults to None. + stop (optional): The stop parameter for generating completions. Defaults to None. + max_tokens (float, optional): The maximum number of tokens to generate. Defaults to None. + presence_penalty (float, optional): The presence penalty for generating completions. Defaults to None. + frequency_penalty (float, optional): The frequency penalty for generating completions. Defaults to None. + logit_bias (dict, optional): The logit bias for generating completions. Defaults to {}. + user (str, optional): The user string for generating completions. Defaults to "". + deployment_id (optional): The deployment ID for generating completions. Defaults to None. + request_timeout (int, optional): The request timeout for generating completions. Defaults to None. + + Returns: + list: A list of completion results. + """ + args = locals() + + batch_messages = messages + completions = [] + model = model + custom_llm_provider = None + if model.split("/", 1)[0] in litellm.provider_list: + custom_llm_provider = model.split("/", 1)[0] + model = model.split("/", 1)[1] + if custom_llm_provider == "vllm": + optional_params = get_optional_params( + functions=functions, + function_call=function_call, + temperature=temperature, + top_p=top_p, + n=n, + stream=stream, + stop=stop, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + user=user, + # params to identify the model + model=model, + custom_llm_provider=custom_llm_provider, + ) + results = vllm.batch_completions( + model=model, + messages=batch_messages, + custom_prompt_dict=litellm.custom_prompt_dict, + optional_params=optional_params, + ) + # all non VLLM models for batch completion models + else: + + def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + with ThreadPoolExecutor(max_workers=100) as executor: + for sub_batch in chunks(batch_messages, 100): + for message_list in sub_batch: + kwargs_modified = args.copy() + kwargs_modified["messages"] = message_list + original_kwargs = {} + if "kwargs" in kwargs_modified: + original_kwargs = kwargs_modified.pop("kwargs") + future = executor.submit( + completion, **kwargs_modified, **original_kwargs + ) + completions.append(future) + + # Retrieve the results from the futures + # results = [future.result() for future in completions] + # return exceptions if any + results = [] + for future in completions: + try: + results.append(future.result()) + except Exception as exc: + results.append(exc) + + return results + + +# send one request to multiple models +# return as soon as one of the llms responds +def batch_completion_models(*args, **kwargs): + """ + Send a request to multiple language models concurrently and return the response + as soon as one of the models responds. + + Args: + *args: Variable-length positional arguments passed to the completion function. + **kwargs: Additional keyword arguments: + - models (str or list of str): The language models to send requests to. + - Other keyword arguments to be passed to the completion function. + + Returns: + str or None: The response from one of the language models, or None if no response is received. + + Note: + This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models. + It sends requests concurrently and returns the response from the first model that responds. + """ + import concurrent + + if "model" in kwargs: + kwargs.pop("model") + if "models" in kwargs: + models = kwargs["models"] + kwargs.pop("models") + futures = {} + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(models) + ) as executor: + for model in models: + futures[model] = executor.submit( + completion, *args, model=model, **kwargs + ) + + for model, future in sorted( + futures.items(), key=lambda x: models.index(x[0]) + ): + if future.result() is not None: + return future.result() + elif "deployments" in kwargs: + deployments = kwargs["deployments"] + kwargs.pop("deployments") + kwargs.pop("model_list") + nested_kwargs = kwargs.pop("kwargs", {}) + futures = {} + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(deployments) + ) as executor: + for deployment in deployments: + for key in kwargs.keys(): + if ( + key not in deployment + ): # don't override deployment values e.g. model name, api base, etc. + deployment[key] = kwargs[key] + kwargs = {**deployment, **nested_kwargs} + futures[deployment["model"]] = executor.submit( + completion, **kwargs + ) + + while futures: + # wait for the first returned future + print_verbose("\n\n waiting for next result\n\n") + done, _ = concurrent.futures.wait( + futures.values(), + return_when=concurrent.futures.FIRST_COMPLETED, + ) + print_verbose(f"done list\n{done}") + for future in done: + try: + result = future.result() + return result + except Exception as e: + # if model 1 fails, continue with response from model 2, model3 + print_verbose( + f"\n\ngot an exception, ignoring, removing from futures" + ) + print_verbose(futures) + new_futures = {} + for key, value in futures.items(): + if future == value: + print_verbose(f"removing key{key}") + continue + else: + new_futures[key] = value + futures = new_futures + print_verbose(f"new futures{futures}") + continue + + print_verbose("\n\ndone looping through futures\n\n") + print_verbose(futures) + + return None # If no response is received from any model + + +def batch_completion_models_all_responses(*args, **kwargs): + """ + Send a request to multiple language models concurrently and return a list of responses + from all models that respond. + + Args: + *args: Variable-length positional arguments passed to the completion function. + **kwargs: Additional keyword arguments: + - models (str or list of str): The language models to send requests to. + - Other keyword arguments to be passed to the completion function. + + Returns: + list: A list of responses from the language models that responded. + + Note: + This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models. + It sends requests concurrently and collects responses from all models that respond. + """ + import concurrent.futures + + # ANSI escape codes for colored output + GREEN = "\033[92m" + RED = "\033[91m" + RESET = "\033[0m" + + if "model" in kwargs: + kwargs.pop("model") + if "models" in kwargs: + models = kwargs["models"] + kwargs.pop("models") + + responses = [] + + with concurrent.futures.ThreadPoolExecutor( + max_workers=len(models) + ) as executor: + for idx, model in enumerate(models): + future = executor.submit(completion, *args, model=model, **kwargs) + if future.result() is not None: + responses.append(future.result()) + + return responses + + +### EMBEDDING ENDPOINTS #################### +@client +async def aembedding(*args, **kwargs) -> EmbeddingResponse: + """ + Asynchronously calls the `embedding` function with the given arguments and keyword arguments. + + Parameters: + - `args` (tuple): Positional arguments to be passed to the `embedding` function. + - `kwargs` (dict): Keyword arguments to be passed to the `embedding` function. + + Returns: + - `response` (Any): The response returned by the `embedding` function. + """ + loop = asyncio.get_event_loop() + model = args[0] if len(args) > 0 else kwargs["model"] + ### PASS ARGS TO Embedding ### + kwargs["aembedding"] = True + custom_llm_provider = None + try: + # Use a partial function to pass your keyword arguments + func = partial(embedding, *args, **kwargs) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=kwargs.get("api_base", None) + ) + + if ( + custom_llm_provider == "openai" + or custom_llm_provider == "azure" + or custom_llm_provider == "xinference" + or custom_llm_provider == "voyage" + or custom_llm_provider == "mistral" + or custom_llm_provider == "custom_openai" + or custom_llm_provider == "triton" + or custom_llm_provider == "anyscale" + or custom_llm_provider == "openrouter" + or custom_llm_provider == "deepinfra" + or custom_llm_provider == "perplexity" + or custom_llm_provider == "groq" + or custom_llm_provider == "nvidia_nim" + or custom_llm_provider == "cerebras" + or custom_llm_provider == "ai21_chat" + or custom_llm_provider == "volcengine" + or custom_llm_provider == "deepseek" + or custom_llm_provider == "fireworks_ai" + or custom_llm_provider == "ollama" + or custom_llm_provider == "vertex_ai" + or custom_llm_provider == "gemini" + or custom_llm_provider == "databricks" + or custom_llm_provider == "watsonx" + or custom_llm_provider == "cohere" + or custom_llm_provider == "huggingface" + or custom_llm_provider == "bedrock" + ): # currently implemented aiohttp calls for just azure and openai, soon all. + # Await normally + init_response = await loop.run_in_executor(None, func_with_context) + if isinstance(init_response, dict): + response = EmbeddingResponse(**init_response) + elif isinstance( + init_response, EmbeddingResponse + ): ## CACHING SCENARIO + response = init_response + elif asyncio.iscoroutine(init_response): + response = await init_response + else: + # Call the synchronous function using run_in_executor + response = await loop.run_in_executor(None, func_with_context) + if response is not None and hasattr(response, "_hidden_params"): + response._hidden_params[ + "custom_llm_provider" + ] = custom_llm_provider + return response + except Exception as e: + custom_llm_provider = custom_llm_provider or "openai" + raise exception_type( + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + extra_kwargs=kwargs, + ) + + +@client +def embedding( + model, + input=[], + # Optional params + dimensions: Optional[int] = None, + timeout=600, # default to 10 minutes + # set api_base, api_version, api_key + api_base: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + api_type: Optional[str] = None, + caching: bool = False, + user: Optional[str] = None, + custom_llm_provider=None, + litellm_call_id=None, + litellm_logging_obj=None, + logger_fn=None, + **kwargs, +) -> EmbeddingResponse: + """ + Embedding function that calls an API to generate embeddings for the given input. + + Parameters: + - model: The embedding model to use. + - input: The input for which embeddings are to be generated. + - dimensions: The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models. + - timeout: The timeout value for the API call, default 10 mins + - litellm_call_id: The call ID for litellm logging. + - litellm_logging_obj: The litellm logging object. + - logger_fn: The logger function. + - api_base: Optional. The base URL for the API. + - api_version: Optional. The version of the API. + - api_key: Optional. The API key to use. + - api_type: Optional. The type of the API. + - caching: A boolean indicating whether to enable caching. + - custom_llm_provider: The custom llm provider. + + Returns: + - response: The response received from the API call. + + Raises: + - exception_type: If an exception occurs during the API call. + """ + azure = kwargs.get("azure", None) + client = kwargs.pop("client", None) + rpm = kwargs.pop("rpm", None) + tpm = kwargs.pop("tpm", None) + cooldown_time = kwargs.get("cooldown_time", None) + max_parallel_requests = kwargs.pop("max_parallel_requests", None) + model_info = kwargs.get("model_info", None) + metadata = kwargs.get("metadata", None) + encoding_format = kwargs.get("encoding_format", None) + proxy_server_request = kwargs.get("proxy_server_request", None) + aembedding = kwargs.get("aembedding", None) + extra_headers = kwargs.get("extra_headers", None) + ### CUSTOM MODEL COST ### + input_cost_per_token = kwargs.get("input_cost_per_token", None) + output_cost_per_token = kwargs.get("output_cost_per_token", None) + input_cost_per_second = kwargs.get("input_cost_per_second", None) + output_cost_per_second = kwargs.get("output_cost_per_second", None) + openai_params = [ + "user", + "dimensions", + "request_timeout", + "api_base", + "api_version", + "api_key", + "deployment_id", + "organization", + "base_url", + "default_headers", + "timeout", + "max_retries", + "encoding_format", + ] + litellm_params = [ + "metadata", + "aembedding", + "caching", + "mock_response", + "api_key", + "api_version", + "api_base", + "force_timeout", + "logger_fn", + "verbose", + "custom_llm_provider", + "litellm_logging_obj", + "litellm_call_id", + "use_client", + "id", + "fallbacks", + "azure", + "headers", + "model_list", + "num_retries", + "context_window_fallback_dict", + "retry_policy", + "roles", + "final_prompt_value", + "bos_token", + "eos_token", + "request_timeout", + "complete_response", + "self", + "client", + "rpm", + "tpm", + "max_parallel_requests", + "input_cost_per_token", + "output_cost_per_token", + "input_cost_per_second", + "output_cost_per_second", + "hf_model_name", + "proxy_server_request", + "model_info", + "preset_cache_key", + "caching_groups", + "ttl", + "cache", + "no-log", + "region_name", + "allowed_model_region", + "model_config", + "cooldown_time", + "tags", + "azure_ad_token_provider", + "tenant_id", + "client_id", + "client_secret", + "extra_headers", + ] + default_params = openai_params + litellm_params + non_default_params = { + k: v for k, v in kwargs.items() if k not in default_params + } # model-specific params - pass them straight to the model/provider + + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( + model=model, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + api_key=api_key, + ) + optional_params = get_optional_params_embeddings( + model=model, + user=user, + dimensions=dimensions, + encoding_format=encoding_format, + custom_llm_provider=custom_llm_provider, + **non_default_params, + ) + ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### + if input_cost_per_token is not None and output_cost_per_token is not None: + litellm.register_model( + { + model: { + "input_cost_per_token": input_cost_per_token, + "output_cost_per_token": output_cost_per_token, + "litellm_provider": custom_llm_provider, + } + } + ) + if ( + input_cost_per_second is not None + ): # time based pricing just needs cost in place + output_cost_per_second = output_cost_per_second or 0.0 + litellm.register_model( + { + model: { + "input_cost_per_second": input_cost_per_second, + "output_cost_per_second": output_cost_per_second, + "litellm_provider": custom_llm_provider, + } + } + ) + try: + response = None + logging: Logging = litellm_logging_obj # type: ignore + logging.update_environment_variables( + model=model, + user=user, + optional_params=optional_params, + litellm_params={ + "timeout": timeout, + "azure": azure, + "litellm_call_id": litellm_call_id, + "logger_fn": logger_fn, + "proxy_server_request": proxy_server_request, + "model_info": model_info, + "metadata": metadata, + "aembedding": aembedding, + "preset_cache_key": None, + "stream_response": {}, + "cooldown_time": cooldown_time, + }, + ) + if azure is True or custom_llm_provider == "azure": + # azure configs + api_type = get_secret("AZURE_API_TYPE") or "azure" + + api_base = ( + api_base or litellm.api_base or get_secret("AZURE_API_BASE") + ) + + api_version = ( + api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + or litellm.AZURE_DEFAULT_API_VERSION + ) + + azure_ad_token = optional_params.pop( + "azure_ad_token", None + ) or get_secret("AZURE_AD_TOKEN") + + api_key = ( + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_API_KEY") + ) + ## EMBEDDING CALL + response = azure_chat_completions.embedding( + model=model, + input=input, + api_base=api_base, + api_key=api_key, + api_version=api_version, + azure_ad_token=azure_ad_token, + logging_obj=logging, + timeout=timeout, + model_response=EmbeddingResponse(), + optional_params=optional_params, + client=client, + aembedding=aembedding, + ) + elif ( + model in litellm.open_ai_embedding_models + or custom_llm_provider == "openai" + ): + api_base = ( + api_base + or litellm.api_base + or get_secret("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) + openai.organization = ( + litellm.organization + or get_secret("OPENAI_ORGANIZATION") + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + api_key + or litellm.api_key + or litellm.openai_key + or get_secret("OPENAI_API_KEY") + ) + api_type = "openai" + api_version = None + + ## EMBEDDING CALL + response = openai_chat_completions.embedding( + model=model, + input=input, + api_base=api_base, + api_key=api_key, + logging_obj=logging, + timeout=timeout, + model_response=EmbeddingResponse(), + optional_params=optional_params, + client=client, + aembedding=aembedding, + ) + elif custom_llm_provider == "databricks": + api_base = ( + api_base + or litellm.api_base + or get_secret("DATABRICKS_API_BASE") + ) # type: ignore + + # set API KEY + api_key = ( + api_key + or litellm.api_key + or litellm.databricks_key + or get_secret("DATABRICKS_API_KEY") + ) # type: ignore + + ## EMBEDDING CALL + response = databricks_chat_completions.embedding( + model=model, + input=input, + api_base=api_base, + api_key=api_key, + logging_obj=logging, + timeout=timeout, + model_response=EmbeddingResponse(), + optional_params=optional_params, + client=client, + aembedding=aembedding, + ) + elif ( + custom_llm_provider == "cohere" + or custom_llm_provider == "cohere_chat" + ): + cohere_key = ( + api_key + or litellm.cohere_key + or get_secret("COHERE_API_KEY") + or get_secret("CO_API_KEY") + or litellm.api_key + ) + + if extra_headers is not None and isinstance(extra_headers, dict): + headers = extra_headers + else: + headers = {} + response = cohere_embed.embedding( + model=model, + input=input, + optional_params=optional_params, + encoding=encoding, + api_key=cohere_key, # type: ignore + headers=headers, + logging_obj=logging, + model_response=EmbeddingResponse(), + aembedding=aembedding, + timeout=timeout, + client=client, + ) + elif custom_llm_provider == "huggingface": + api_key = ( + api_key + or litellm.huggingface_key + or get_secret("HUGGINGFACE_API_KEY") + or litellm.api_key + ) # type: ignore + response = huggingface.embedding( + model=model, + input=input, + encoding=encoding, # type: ignore + api_key=api_key, + api_base=api_base, + logging_obj=logging, + model_response=EmbeddingResponse(), + optional_params=optional_params, + client=client, + aembedding=aembedding, + ) + elif custom_llm_provider == "bedrock": + if isinstance(input, str): + transformed_input = [input] + else: + transformed_input = input + response = bedrock_embedding.embeddings( + model=model, + input=transformed_input, + encoding=encoding, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + client=client, + timeout=timeout, + aembedding=aembedding, + litellm_params=litellm_params, + api_base=api_base, + print_verbose=print_verbose, + extra_headers=extra_headers, + ) + elif custom_llm_provider == "triton": + if api_base is None: + raise ValueError( + "api_base is required for triton. Please pass `api_base`" + ) + response = triton_chat_completions.embedding( + model=model, + input=input, + api_base=api_base, + api_key=api_key, + logging_obj=logging, + timeout=timeout, + model_response=EmbeddingResponse(), + optional_params=optional_params, + client=client, + aembedding=aembedding, + ) + elif custom_llm_provider == "gemini": + gemini_api_key = ( + api_key or get_secret("GEMINI_API_KEY") or litellm.api_key + ) + + response = google_batch_embeddings.batch_embeddings( # type: ignore + model=model, + input=input, + encoding=encoding, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + vertex_project=None, + vertex_location=None, + vertex_credentials=None, + aembedding=aembedding, + print_verbose=print_verbose, + custom_llm_provider="gemini", + api_key=gemini_api_key, + ) + + elif custom_llm_provider == "vertex_ai": + vertex_ai_project = ( + optional_params.pop("vertex_project", None) + or optional_params.pop("vertex_ai_project", None) + or litellm.vertex_project + or get_secret("VERTEXAI_PROJECT") + or get_secret("VERTEX_PROJECT") + ) + vertex_ai_location = ( + optional_params.pop("vertex_location", None) + or optional_params.pop("vertex_ai_location", None) + or litellm.vertex_location + or get_secret("VERTEXAI_LOCATION") + or get_secret("VERTEX_LOCATION") + ) + vertex_credentials = ( + optional_params.pop("vertex_credentials", None) + or optional_params.pop("vertex_ai_credentials", None) + or get_secret("VERTEXAI_CREDENTIALS") + or get_secret("VERTEX_CREDENTIALS") + ) + + if ( + "image" in optional_params + or "video" in optional_params + or model + in vertex_multimodal_embedding.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS + ): + # multimodal embedding is supported on vertex httpx + response = vertex_multimodal_embedding.multimodal_embedding( + model=model, + input=input, + encoding=encoding, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + vertex_project=vertex_ai_project, + vertex_location=vertex_ai_location, + vertex_credentials=vertex_credentials, + aembedding=aembedding, + print_verbose=print_verbose, + custom_llm_provider="vertex_ai", + ) + else: + response = vertex_ai_embedding_handler.embedding( + model=model, + input=input, + encoding=encoding, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + vertex_project=vertex_ai_project, + vertex_location=vertex_ai_location, + vertex_credentials=vertex_credentials, + aembedding=aembedding, + print_verbose=print_verbose, + ) + elif custom_llm_provider == "oobabooga": + response = oobabooga.embedding( + model=model, + input=input, + encoding=encoding, + api_base=api_base, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + ) + elif custom_llm_provider == "ollama": + api_base = ( + litellm.api_base + or api_base + or get_secret("OLLAMA_API_BASE") + or "http://localhost:11434" + ) # type: ignore + if isinstance(input, str): + input = [input] + if not all(isinstance(item, str) for item in input): + raise litellm.BadRequestError( + message=f"Invalid input for ollama embeddings. input={input}", + model=model, # type: ignore + llm_provider="ollama", # type: ignore + ) + ollama_embeddings_fn = ( + ollama.ollama_aembeddings + if aembedding is True + else ollama.ollama_embeddings + ) + response = ollama_embeddings_fn( # type: ignore + api_base=api_base, + model=model, + prompts=input, + encoding=encoding, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + ) + elif custom_llm_provider == "sagemaker": + response = sagemaker_llm.embedding( + model=model, + input=input, + encoding=encoding, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + print_verbose=print_verbose, + ) + elif custom_llm_provider == "mistral": + api_key = ( + api_key or litellm.api_key or get_secret("MISTRAL_API_KEY") + ) + response = openai_chat_completions.embedding( + model=model, + input=input, + api_base=api_base, + api_key=api_key, + logging_obj=logging, + timeout=timeout, + model_response=EmbeddingResponse(), + optional_params=optional_params, + client=client, + aembedding=aembedding, + ) + elif custom_llm_provider == "voyage": + api_key = ( + api_key or litellm.api_key or get_secret("VOYAGE_API_KEY") + ) + response = openai_chat_completions.embedding( + model=model, + input=input, + api_base=api_base, + api_key=api_key, + logging_obj=logging, + timeout=timeout, + model_response=EmbeddingResponse(), + optional_params=optional_params, + client=client, + aembedding=aembedding, + ) + elif custom_llm_provider == "xinference": + api_key = ( + api_key + or litellm.api_key + or get_secret("XINFERENCE_API_KEY") + or "stub-xinference-key" + ) # xinference does not need an api key, pass a stub key if user did not set one + api_base = ( + api_base + or litellm.api_base + or get_secret("XINFERENCE_API_BASE") + or "http://127.0.0.1:9997/v1" + ) + response = openai_chat_completions.embedding( + model=model, + input=input, + api_base=api_base, + api_key=api_key, + logging_obj=logging, + timeout=timeout, + model_response=EmbeddingResponse(), + optional_params=optional_params, + client=client, + aembedding=aembedding, + ) + elif custom_llm_provider == "watsonx": + response = watsonxai.embedding( + model=model, + input=input, + encoding=encoding, + logging_obj=logging, + optional_params=optional_params, + model_response=EmbeddingResponse(), + aembedding=aembedding, + ) + else: + args = locals() + raise ValueError( + f"No valid embedding model args passed in - {args}" + ) + if response is not None and hasattr(response, "_hidden_params"): + response._hidden_params[ + "custom_llm_provider" + ] = custom_llm_provider + return response + except Exception as e: + ## LOGGING + logging.post_call( + input=input, + api_key=api_key, + original_response=str(e), + ) + ## Map to OpenAI Exception + raise exception_type( + model=model, + original_exception=e, + custom_llm_provider=custom_llm_provider, + extra_kwargs=kwargs, + ) + + +###### Text Completion ################ +@client +async def atext_completion( + *args, **kwargs +) -> Union[TextCompletionResponse, TextCompletionStreamWrapper]: + """ + Implemented to handle async streaming for the text completion endpoint + """ + loop = asyncio.get_event_loop() + model = args[0] if len(args) > 0 else kwargs["model"] + ### PASS ARGS TO COMPLETION ### + kwargs["acompletion"] = True + custom_llm_provider = None + try: + # Use a partial function to pass your keyword arguments + func = partial(text_completion, *args, **kwargs) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=kwargs.get("api_base", None) + ) + + if ( + custom_llm_provider == "openai" + or custom_llm_provider == "azure" + or custom_llm_provider == "azure_text" + or custom_llm_provider == "custom_openai" + or custom_llm_provider == "anyscale" + or custom_llm_provider == "mistral" + or custom_llm_provider == "openrouter" + or custom_llm_provider == "deepinfra" + or custom_llm_provider == "perplexity" + or custom_llm_provider == "groq" + or custom_llm_provider == "nvidia_nim" + or custom_llm_provider == "cerebras" + or custom_llm_provider == "ai21_chat" + or custom_llm_provider == "volcengine" + or custom_llm_provider == "text-completion-codestral" + or custom_llm_provider == "deepseek" + or custom_llm_provider == "fireworks_ai" + or custom_llm_provider == "text-completion-openai" + or custom_llm_provider == "huggingface" + or custom_llm_provider == "ollama" + or custom_llm_provider == "vertex_ai" + or custom_llm_provider in litellm.openai_compatible_providers + ): # currently implemented aiohttp calls for just azure and openai, soon all. + # Await normally + response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(response): + response = await response + else: + # Call the synchronous function using run_in_executor + response = await loop.run_in_executor(None, func_with_context) + if kwargs.get("stream", False) is True: # return an async generator + return TextCompletionStreamWrapper( + completion_stream=_async_streaming( + response=response, + model=model, + custom_llm_provider=custom_llm_provider, + args=args, + ), + model=model, + custom_llm_provider=custom_llm_provider, + ) + else: + transformed_logprobs = None + # only supported for TGI models + try: + raw_response = response._hidden_params.get( + "original_response", None + ) + transformed_logprobs = litellm.utils.transform_logprobs( + raw_response + ) + except Exception as e: + print_verbose(f"LiteLLM non blocking exception: {e}") + + ## TRANSLATE CHAT TO TEXT FORMAT ## + if isinstance(response, TextCompletionResponse): + return response + elif asyncio.iscoroutine(response): + response = await response + + text_completion_response = TextCompletionResponse() + text_completion_response["id"] = response.get("id", None) + text_completion_response["object"] = "text_completion" + text_completion_response["created"] = response.get("created", None) + text_completion_response["model"] = response.get("model", None) + text_choices = TextChoices() + text_choices["text"] = response["choices"][0]["message"]["content"] + text_choices["index"] = response["choices"][0]["index"] + text_choices["logprobs"] = transformed_logprobs + text_choices["finish_reason"] = response["choices"][0][ + "finish_reason" + ] + text_completion_response["choices"] = [text_choices] + text_completion_response["usage"] = response.get("usage", None) + text_completion_response._hidden_params = HiddenParams( + **response._hidden_params + ) + return text_completion_response + except Exception as e: + custom_llm_provider = custom_llm_provider or "openai" + raise exception_type( + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + extra_kwargs=kwargs, + ) + + +@client +def text_completion( + prompt: Union[ + str, List[Union[str, List[Union[str, List[int]]]]] + ], # Required: The prompt(s) to generate completions for. + model: Optional[ + str + ] = None, # Optional: either `model` or `engine` can be set + best_of: Optional[ + int + ] = None, # Optional: Generates best_of completions server-side. + echo: Optional[ + bool + ] = None, # Optional: Echo back the prompt in addition to the completion. + frequency_penalty: Optional[ + float + ] = None, # Optional: Penalize new tokens based on their existing frequency. + logit_bias: Optional[ + Dict[int, int] + ] = None, # Optional: Modify the likelihood of specified tokens. + logprobs: Optional[ + int + ] = None, # Optional: Include the log probabilities on the most likely tokens. + max_tokens: Optional[ + int + ] = None, # Optional: The maximum number of tokens to generate in the completion. + n: Optional[ + int + ] = None, # Optional: How many completions to generate for each prompt. + presence_penalty: Optional[ + float + ] = None, # Optional: Penalize new tokens based on whether they appear in the text so far. + stop: Optional[ + Union[str, List[str]] + ] = None, # Optional: Sequences where the API will stop generating further tokens. + stream: Optional[ + bool + ] = None, # Optional: Whether to stream back partial progress. + stream_options: Optional[dict] = None, + suffix: Optional[ + str + ] = None, # Optional: The suffix that comes after a completion of inserted text. + temperature: Optional[ + float + ] = None, # Optional: Sampling temperature to use. + top_p: Optional[float] = None, # Optional: Nucleus sampling parameter. + user: Optional[ + str + ] = None, # Optional: A unique identifier representing your end-user. + # set api_base, api_version, api_key + api_base: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. + # Optional liteLLM function params + custom_llm_provider: Optional[str] = None, + *args, + **kwargs, +): + global print_verbose + import copy + + """ + Generate text completions using the OpenAI API. + + Args: + model (str): ID of the model to use. + prompt (Union[str, List[Union[str, List[Union[str, List[int]]]]]): The prompt(s) to generate completions for. + best_of (Optional[int], optional): Generates best_of completions server-side. Defaults to 1. + echo (Optional[bool], optional): Echo back the prompt in addition to the completion. Defaults to False. + frequency_penalty (Optional[float], optional): Penalize new tokens based on their existing frequency. Defaults to 0. + logit_bias (Optional[Dict[int, int]], optional): Modify the likelihood of specified tokens. Defaults to None. + logprobs (Optional[int], optional): Include the log probabilities on the most likely tokens. Defaults to None. + max_tokens (Optional[int], optional): The maximum number of tokens to generate in the completion. Defaults to 16. + n (Optional[int], optional): How many completions to generate for each prompt. Defaults to 1. + presence_penalty (Optional[float], optional): Penalize new tokens based on whether they appear in the text so far. Defaults to 0. + stop (Optional[Union[str, List[str]]], optional): Sequences where the API will stop generating further tokens. Defaults to None. + stream (Optional[bool], optional): Whether to stream back partial progress. Defaults to False. + suffix (Optional[str], optional): The suffix that comes after a completion of inserted text. Defaults to None. + temperature (Optional[float], optional): Sampling temperature to use. Defaults to 1. + top_p (Optional[float], optional): Nucleus sampling parameter. Defaults to 1. + user (Optional[str], optional): A unique identifier representing your end-user. + Returns: + TextCompletionResponse: A response object containing the generated completion and associated metadata. + + Example: + Your example of how to use this function goes here. + """ + if "engine" in kwargs: + if model == None: + # only use engine when model not passed + model = kwargs["engine"] + kwargs.pop("engine") + + text_completion_response = TextCompletionResponse() + + optional_params: Dict[str, Any] = {} + # default values for all optional params are none, litellm only passes them to the llm when they are set to non None values + if best_of is not None: + optional_params["best_of"] = best_of + if echo is not None: + optional_params["echo"] = echo + if frequency_penalty is not None: + optional_params["frequency_penalty"] = frequency_penalty + if logit_bias is not None: + optional_params["logit_bias"] = logit_bias + if logprobs is not None: + optional_params["logprobs"] = logprobs + if max_tokens is not None: + optional_params["max_tokens"] = max_tokens + if n is not None: + optional_params["n"] = n + if presence_penalty is not None: + optional_params["presence_penalty"] = presence_penalty + if stop is not None: + optional_params["stop"] = stop + if stream is not None: + optional_params["stream"] = stream + if stream_options is not None: + optional_params["stream_options"] = stream_options + if suffix is not None: + optional_params["suffix"] = suffix + if temperature is not None: + optional_params["temperature"] = temperature + if top_p is not None: + optional_params["top_p"] = top_p + if user is not None: + optional_params["user"] = user + if api_base is not None: + optional_params["api_base"] = api_base + if api_version is not None: + optional_params["api_version"] = api_version + if api_key is not None: + optional_params["api_key"] = api_key + if custom_llm_provider is not None: + optional_params["custom_llm_provider"] = custom_llm_provider + + # get custom_llm_provider + _model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore + + if custom_llm_provider == "huggingface": + # if echo == True, for TGI llms we need to set top_n_tokens to 3 + if echo == True: + # for tgi llms + if "top_n_tokens" not in kwargs: + kwargs["top_n_tokens"] = 3 + + # processing prompt - users can pass raw tokens to OpenAI Completion() + if type(prompt) == list: + import concurrent.futures + + tokenizer = tiktoken.encoding_for_model("text-davinci-003") + ## if it's a 2d list - each element in the list is a text_completion() request + if len(prompt) > 0 and type(prompt[0]) == list: + responses = [None for x in prompt] # init responses + + def process_prompt(i, individual_prompt): + decoded_prompt = tokenizer.decode(individual_prompt) + all_params = {**kwargs, **optional_params} + response = text_completion( + model=model, + prompt=decoded_prompt, + num_retries=3, # ensure this does not fail for the batch + *args, + **all_params, + ) + + text_completion_response["id"] = response.get("id", None) + text_completion_response["object"] = "text_completion" + text_completion_response["created"] = response.get( + "created", None + ) + text_completion_response["model"] = response.get( + "model", None + ) + return response["choices"][0] + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit(process_prompt, i, individual_prompt) + for i, individual_prompt in enumerate(prompt) + ] + for i, future in enumerate( + concurrent.futures.as_completed(futures) + ): + responses[i] = future.result() + text_completion_response.choices = responses # type: ignore + + return text_completion_response + # else: + # check if non default values passed in for best_of, echo, logprobs, suffix + # these are the params supported by Completion() but not ChatCompletion + + # default case, non OpenAI requests go through here + # handle prompt formatting if prompt is a string vs. list of strings + messages = [] + if ( + isinstance(prompt, list) + and len(prompt) > 0 + and isinstance(prompt[0], str) + ): + for p in prompt: + message = {"role": "user", "content": p} + messages.append(message) + elif isinstance(prompt, str): + messages = [{"role": "user", "content": prompt}] + elif ( + ( + custom_llm_provider == "openai" + or custom_llm_provider == "azure" + or custom_llm_provider == "azure_text" + or custom_llm_provider == "text-completion-codestral" + or custom_llm_provider == "text-completion-openai" + ) + and isinstance(prompt, list) + and len(prompt) > 0 + and isinstance(prompt[0], list) + ): + verbose_logger.warning( + msg="List of lists being passed. If this is for tokens, then it might not work across all models." + ) + messages = [{"role": "user", "content": prompt}] # type: ignore + else: + raise Exception( + f"Unmapped prompt format. Your prompt is neither a list of strings nor a string. prompt={prompt}. File an issue - https://github.com/BerriAI/litellm/issues" + ) + + kwargs.pop("prompt", None) + + if _model is not None and ( + custom_llm_provider == "openai" + ): # for openai compatible endpoints - e.g. vllm, call the native /v1/completions endpoint for text completion calls + if _model not in litellm.open_ai_chat_completion_models: + model = "text-completion-openai/" + _model + optional_params.pop("custom_llm_provider", None) + + kwargs["text_completion"] = True + response = completion( + model=model, + messages=messages, + *args, + **kwargs, + **optional_params, + ) + if kwargs.get("acompletion", False) is True: + return response + if stream is True or kwargs.get("stream", False) is True: + response = TextCompletionStreamWrapper( + completion_stream=response, + model=model, + stream_options=stream_options, + custom_llm_provider=custom_llm_provider, + ) + return response + transformed_logprobs = None + # only supported for TGI models + try: + raw_response = response._hidden_params.get("original_response", None) + transformed_logprobs = litellm.utils.transform_logprobs(raw_response) + except Exception as e: + print_verbose(f"LiteLLM non blocking exception: {e}") + + if isinstance(response, TextCompletionResponse): + return response + + text_completion_response["id"] = response.get("id", None) + text_completion_response["object"] = "text_completion" + text_completion_response["created"] = response.get("created", None) + text_completion_response["model"] = response.get("model", None) + text_choices = TextChoices() + text_choices["text"] = response["choices"][0]["message"]["content"] + text_choices["index"] = response["choices"][0]["index"] + text_choices["logprobs"] = transformed_logprobs + text_choices["finish_reason"] = response["choices"][0]["finish_reason"] + text_completion_response["choices"] = [text_choices] + text_completion_response["usage"] = response.get("usage", None) + text_completion_response._hidden_params = HiddenParams( + **response._hidden_params + ) + + return text_completion_response + + +###### Adapter Completion ################ + + +async def aadapter_completion( + *, adapter_id: str, **kwargs +) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]: + """ + Implemented to handle async calls for adapter_completion() + """ + try: + translation_obj: Optional[CustomLogger] = None + for item in litellm.adapters: + if item["id"] == adapter_id: + translation_obj = item["adapter"] + + if translation_obj is None: + raise ValueError( + "No matching adapter given. Received 'adapter_id'={}, litellm.adapters={}".format( + adapter_id, litellm.adapters + ) + ) + + new_kwargs = translation_obj.translate_completion_input_params( + kwargs=kwargs + ) + + response: Union[ModelResponse, CustomStreamWrapper] = await acompletion(**new_kwargs) # type: ignore + translated_response: Optional[ + Union[BaseModel, AdapterCompletionStreamWrapper] + ] = None + if isinstance(response, ModelResponse): + translated_response = ( + translation_obj.translate_completion_output_params( + response=response + ) + ) + if isinstance(response, CustomStreamWrapper): + translated_response = ( + translation_obj.translate_completion_output_params_streaming( + completion_stream=response + ) + ) + + return translated_response + except Exception as e: + raise e + + +def adapter_completion( + *, adapter_id: str, **kwargs +) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]: + translation_obj: Optional[CustomLogger] = None + for item in litellm.adapters: + if item["id"] == adapter_id: + translation_obj = item["adapter"] + + if translation_obj is None: + raise ValueError( + "No matching adapter given. Received 'adapter_id'={}, litellm.adapters={}".format( + adapter_id, litellm.adapters + ) + ) + + new_kwargs = translation_obj.translate_completion_input_params( + kwargs=kwargs + ) + + response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore + translated_response: Optional[ + Union[BaseModel, AdapterCompletionStreamWrapper] + ] = None + if isinstance(response, ModelResponse): + translated_response = ( + translation_obj.translate_completion_output_params( + response=response + ) + ) + elif isinstance(response, CustomStreamWrapper) or inspect.isgenerator( + response + ): + translated_response = ( + translation_obj.translate_completion_output_params_streaming( + completion_stream=response + ) + ) + + return translated_response + + +##### Moderation ####################### + + +def moderation( + input: str, + model: Optional[str] = None, + api_key: Optional[str] = None, + **kwargs, +): + # only supports open ai for now + api_key = ( + api_key + or litellm.api_key + or litellm.openai_key + or get_secret("OPENAI_API_KEY") + ) + + openai_client = kwargs.get("client", None) + if openai_client is None: + openai_client = openai.OpenAI( + api_key=api_key, + ) + + response = openai_client.moderations.create(input=input, model=model) + return response + + +@client +async def amoderation( + input: str, + model: Optional[str] = None, + api_key: Optional[str] = None, + **kwargs, +): + # only supports open ai for now + api_key = ( + api_key + or litellm.api_key + or litellm.openai_key + or get_secret("OPENAI_API_KEY") + ) + openai_client = kwargs.get("client", None) + if openai_client is None: + # call helper to get OpenAI client + # _get_openai_client maintains in-memory caching logic for OpenAI clients + openai_client = openai_chat_completions._get_openai_client( + is_async=True, + api_key=api_key, + ) + response = await openai_client.moderations.create(input=input, model=model) + return response + + +##### Image Generation ####################### +@client +async def aimage_generation(*args, **kwargs) -> ImageResponse: + """ + Asynchronously calls the `image_generation` function with the given arguments and keyword arguments. + + Parameters: + - `args` (tuple): Positional arguments to be passed to the `image_generation` function. + - `kwargs` (dict): Keyword arguments to be passed to the `image_generation` function. + + Returns: + - `response` (Any): The response returned by the `image_generation` function. + """ + loop = asyncio.get_event_loop() + model = args[0] if len(args) > 0 else kwargs["model"] + ### PASS ARGS TO Image Generation ### + kwargs["aimg_generation"] = True + custom_llm_provider = None + try: + # Use a partial function to pass your keyword arguments + func = partial(image_generation, *args, **kwargs) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=kwargs.get("api_base", None) + ) + + # Await normally + init_response = await loop.run_in_executor(None, func_with_context) + if isinstance(init_response, dict) or isinstance( + init_response, ImageResponse + ): ## CACHING SCENARIO + if isinstance(init_response, dict): + init_response = ImageResponse(**init_response) + response = init_response + elif asyncio.iscoroutine(init_response): + response = await init_response + else: + # Call the synchronous function using run_in_executor + response = await loop.run_in_executor(None, func_with_context) + return response + except Exception as e: + custom_llm_provider = custom_llm_provider or "openai" + raise exception_type( + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + extra_kwargs=kwargs, + ) + + +@client +def image_generation( + prompt: str, + model: Optional[str] = None, + n: Optional[int] = None, + quality: Optional[str] = None, + response_format: Optional[str] = None, + size: Optional[str] = None, + style: Optional[str] = None, + user: Optional[str] = None, + timeout=600, # default to 10 minutes + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + litellm_logging_obj=None, + custom_llm_provider=None, + **kwargs, +) -> ImageResponse: + """ + Maps the https://api.openai.com/v1/images/generations endpoint. + + Currently supports just Azure + OpenAI. + """ + try: + aimg_generation = kwargs.get("aimg_generation", False) + litellm_call_id = kwargs.get("litellm_call_id", None) + logger_fn = kwargs.get("logger_fn", None) + proxy_server_request = kwargs.get("proxy_server_request", None) + model_info = kwargs.get("model_info", None) + metadata = kwargs.get("metadata", {}) + client = kwargs.get("client", None) + + model_response = litellm.utils.ImageResponse() + if model is not None or custom_llm_provider is not None: + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore + else: + model = "dall-e-2" + custom_llm_provider = "openai" # default to dall-e-2 on openai + model_response._hidden_params["model"] = model + openai_params = [ + "user", + "request_timeout", + "api_base", + "api_version", + "api_key", + "deployment_id", + "organization", + "base_url", + "default_headers", + "timeout", + "max_retries", + "n", + "quality", + "size", + "style", + ] + litellm_params = [ + "metadata", + "aimg_generation", + "caching", + "mock_response", + "api_key", + "api_version", + "api_base", + "force_timeout", + "logger_fn", + "verbose", + "custom_llm_provider", + "litellm_logging_obj", + "litellm_call_id", + "use_client", + "id", + "fallbacks", + "azure", + "headers", + "model_list", + "num_retries", + "context_window_fallback_dict", + "retry_policy", + "roles", + "final_prompt_value", + "bos_token", + "eos_token", + "request_timeout", + "complete_response", + "self", + "client", + "rpm", + "tpm", + "max_parallel_requests", + "input_cost_per_token", + "output_cost_per_token", + "hf_model_name", + "proxy_server_request", + "model_info", + "preset_cache_key", + "caching_groups", + "ttl", + "cache", + "region_name", + "allowed_model_region", + "model_config", + ] + default_params = openai_params + litellm_params + non_default_params = { + k: v for k, v in kwargs.items() if k not in default_params + } # model-specific params - pass them straight to the model/provider + optional_params = get_optional_params_image_gen( + n=n, + quality=quality, + response_format=response_format, + size=size, + style=style, + user=user, + custom_llm_provider=custom_llm_provider, + **non_default_params, + ) + logging: Logging = litellm_logging_obj + logging.update_environment_variables( + model=model, + user=user, + optional_params=optional_params, + litellm_params={ + "timeout": timeout, + "azure": False, + "litellm_call_id": litellm_call_id, + "logger_fn": logger_fn, + "proxy_server_request": proxy_server_request, + "model_info": model_info, + "metadata": metadata, + "preset_cache_key": None, + "stream_response": {}, + }, + custom_llm_provider=custom_llm_provider, + ) + + if custom_llm_provider == "azure": + # azure configs + api_type = get_secret("AZURE_API_TYPE") or "azure" + + api_base = ( + api_base or litellm.api_base or get_secret("AZURE_API_BASE") + ) + + api_version = ( + api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) + + azure_ad_token = optional_params.pop( + "azure_ad_token", None + ) or get_secret("AZURE_AD_TOKEN") + + model_response = azure_chat_completions.image_generation( + model=model, + prompt=prompt, + timeout=timeout, + api_key=api_key, + api_base=api_base, + logging_obj=litellm_logging_obj, + optional_params=optional_params, + model_response=model_response, + api_version=api_version, + aimg_generation=aimg_generation, + client=client, + ) + elif custom_llm_provider == "openai": + model_response = openai_chat_completions.image_generation( + model=model, + prompt=prompt, + timeout=timeout, + api_key=api_key, + api_base=api_base, + logging_obj=litellm_logging_obj, + optional_params=optional_params, + model_response=model_response, + aimg_generation=aimg_generation, + client=client, + ) + elif custom_llm_provider == "bedrock": + if model is None: + raise Exception("Model needs to be set for bedrock") + model_response = bedrock_image_generation.image_generation( + model=model, + prompt=prompt, + timeout=timeout, + logging_obj=litellm_logging_obj, + optional_params=optional_params, + model_response=model_response, + aimg_generation=aimg_generation, + ) + elif custom_llm_provider == "vertex_ai": + vertex_ai_project = ( + optional_params.pop("vertex_project", None) + or optional_params.pop("vertex_ai_project", None) + or litellm.vertex_project + or get_secret("VERTEXAI_PROJECT") + ) + vertex_ai_location = ( + optional_params.pop("vertex_location", None) + or optional_params.pop("vertex_ai_location", None) + or litellm.vertex_location + or get_secret("VERTEXAI_LOCATION") + ) + vertex_credentials = ( + optional_params.pop("vertex_credentials", None) + or optional_params.pop("vertex_ai_credentials", None) + or get_secret("VERTEXAI_CREDENTIALS") + ) + model_response = vertex_image_generation.image_generation( + model=model, + prompt=prompt, + timeout=timeout, + logging_obj=litellm_logging_obj, + optional_params=optional_params, + model_response=model_response, + vertex_project=vertex_ai_project, + vertex_location=vertex_ai_location, + vertex_credentials=vertex_credentials, + aimg_generation=aimg_generation, + ) + + return model_response + except Exception as e: + ## Map to OpenAI Exception + raise exception_type( + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=locals(), + extra_kwargs=kwargs, + ) + + +##### Transcription ####################### + + +@client +async def atranscription(*args, **kwargs) -> TranscriptionResponse: + """ + Calls openai + azure whisper endpoints. + + Allows router to load balance between them + """ + loop = asyncio.get_event_loop() + model = args[0] if len(args) > 0 else kwargs["model"] + ### PASS ARGS TO Image Generation ### + kwargs["atranscription"] = True + custom_llm_provider = None + try: + # Use a partial function to pass your keyword arguments + func = partial(transcription, *args, **kwargs) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=kwargs.get("api_base", None) + ) + + # Await normally + init_response = await loop.run_in_executor(None, func_with_context) + if isinstance(init_response, dict): + response = TranscriptionResponse(**init_response) + elif isinstance( + init_response, TranscriptionResponse + ): ## CACHING SCENARIO + response = init_response + elif asyncio.iscoroutine(init_response): + response = await init_response + else: + # Call the synchronous function using run_in_executor + response = await loop.run_in_executor(None, func_with_context) + return response + except Exception as e: + custom_llm_provider = custom_llm_provider or "openai" + raise exception_type( + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + extra_kwargs=kwargs, + ) + + +@client +def transcription( + model: str, + file: FileTypes, + ## OPTIONAL OPENAI PARAMS ## + language: Optional[str] = None, + prompt: Optional[str] = None, + response_format: Optional[ + Literal["json", "text", "srt", "verbose_json", "vtt"] + ] = None, + temperature: Optional[int] = None, # openai defaults this to 0 + ## LITELLM PARAMS ## + user: Optional[str] = None, + timeout=600, # default to 10 minutes + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + max_retries: Optional[int] = None, + litellm_logging_obj: Optional[LiteLLMLoggingObj] = None, + custom_llm_provider=None, + **kwargs, +) -> TranscriptionResponse: + """ + Calls openai + azure whisper endpoints. + + Allows router to load balance between them + """ + atranscription = kwargs.get("atranscription", False) + litellm_call_id = kwargs.get("litellm_call_id", None) + logger_fn = kwargs.get("logger_fn", None) + proxy_server_request = kwargs.get("proxy_server_request", None) + model_info = kwargs.get("model_info", None) + metadata = kwargs.get("metadata", {}) + tags = kwargs.pop("tags", []) + + drop_params = kwargs.get("drop_params", None) + client: Optional[ + Union[ + openai.AsyncOpenAI, + openai.OpenAI, + openai.AzureOpenAI, + openai.AsyncAzureOpenAI, + ] + ] = kwargs.pop("client", None) + + if litellm_logging_obj: + litellm_logging_obj.model_call_details["client"] = str(client) + + if max_retries is None: + max_retries = openai.DEFAULT_MAX_RETRIES + + model_response = litellm.utils.TranscriptionResponse() + + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore + + if dynamic_api_key is not None: + api_key = dynamic_api_key + + optional_params = get_optional_params_transcription( + model=model, + language=language, + prompt=prompt, + response_format=response_format, + temperature=temperature, + custom_llm_provider=custom_llm_provider, + drop_params=drop_params, + ) + # optional_params = { + # "language": language, + # "prompt": prompt, + # "response_format": response_format, + # "temperature": None, # openai defaults this to 0 + # } + + if custom_llm_provider == "azure": + # azure configs + api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") + + api_version = ( + api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + ) + + azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret( + "AZURE_AD_TOKEN" + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_API_KEY") + ) # type: ignore + + response = azure_audio_transcriptions.audio_transcriptions( + model=model, + audio_file=file, + optional_params=optional_params, + model_response=model_response, + atranscription=atranscription, + client=client, + timeout=timeout, + logging_obj=litellm_logging_obj, + api_base=api_base, + api_key=api_key, + api_version=api_version, + azure_ad_token=azure_ad_token, + max_retries=max_retries, + ) + elif custom_llm_provider == "openai" or custom_llm_provider == "groq": + api_base = ( + api_base + or litellm.api_base + or get_secret("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) # type: ignore + openai.organization = ( + litellm.organization + or get_secret("OPENAI_ORGANIZATION") + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) + # set API KEY + api_key = ( + api_key + or litellm.api_key + or litellm.openai_key + or get_secret("OPENAI_API_KEY") + ) # type: ignore + response = openai_audio_transcriptions.audio_transcriptions( + model=model, + audio_file=file, + optional_params=optional_params, + model_response=model_response, + atranscription=atranscription, + client=client, + timeout=timeout, + logging_obj=litellm_logging_obj, + max_retries=max_retries, + api_base=api_base, + api_key=api_key, + ) + return response + + +@client +async def aspeech(*args, **kwargs) -> HttpxBinaryResponseContent: + """ + Calls openai tts endpoints. + """ + loop = asyncio.get_event_loop() + model = args[0] if len(args) > 0 else kwargs["model"] + ### PASS ARGS TO Image Generation ### + kwargs["aspeech"] = True + custom_llm_provider = kwargs.get("custom_llm_provider", None) + try: + # Use a partial function to pass your keyword arguments + func = partial(speech, *args, **kwargs) + + # Add the context to the function + ctx = contextvars.copy_context() + func_with_context = partial(ctx.run, func) + + _, custom_llm_provider, _, _ = get_llm_provider( + model=model, api_base=kwargs.get("api_base", None) + ) + + # Await normally + init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): + response = await init_response + else: + # Call the synchronous function using run_in_executor + response = await loop.run_in_executor(None, func_with_context) + return response # type: ignore + except Exception as e: + custom_llm_provider = custom_llm_provider or "openai" + raise exception_type( + model=model, + custom_llm_provider=custom_llm_provider, + original_exception=e, + completion_kwargs=args, + extra_kwargs=kwargs, + ) + + +@client +def speech( + model: str, + input: str, + voice: Optional[Union[str, dict]] = None, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + organization: Optional[str] = None, + project: Optional[str] = None, + max_retries: Optional[int] = None, + metadata: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + response_format: Optional[str] = None, + speed: Optional[int] = None, + client=None, + headers: Optional[dict] = None, + custom_llm_provider: Optional[str] = None, + aspeech: Optional[bool] = None, + **kwargs, +) -> HttpxBinaryResponseContent: + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore + tags = kwargs.pop("tags", []) + + optional_params = {} + if response_format is not None: + optional_params["response_format"] = response_format + if speed is not None: + optional_params["speed"] = speed # type: ignore + + if timeout is None: + timeout = litellm.request_timeout + + if max_retries is None: + max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES + + logging_obj = kwargs.get("litellm_logging_obj", None) + response: Optional[HttpxBinaryResponseContent] = None + if custom_llm_provider == "openai": + if voice is None or not (isinstance(voice, str)): + raise litellm.BadRequestError( + message="'voice' is required to be passed as a string for OpenAI TTS", + model=model, + llm_provider=custom_llm_provider, + ) + api_base = ( + api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there + or litellm.api_base + or get_secret("OPENAI_API_BASE") + or "https://api.openai.com/v1" + ) # type: ignore + # set API KEY + api_key = ( + api_key + or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there + or litellm.openai_key + or get_secret("OPENAI_API_KEY") + ) # type: ignore + + organization = ( + organization + or litellm.organization + or get_secret("OPENAI_ORGANIZATION") + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) # type: ignore + + project = ( + project + or litellm.project + or get_secret("OPENAI_PROJECT") + or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 + ) # type: ignore + + headers = headers or litellm.headers + + response = openai_chat_completions.audio_speech( + model=model, + input=input, + voice=voice, + optional_params=optional_params, + api_key=api_key, + api_base=api_base, + organization=organization, + project=project, + max_retries=max_retries, + timeout=timeout, + client=client, # pass AsyncOpenAI, OpenAI client + aspeech=aspeech, + ) + elif custom_llm_provider == "azure": + # azure configs + if voice is None or not (isinstance(voice, str)): + raise litellm.BadRequestError( + message="'voice' is required to be passed as a string for Azure TTS", + model=model, + llm_provider=custom_llm_provider, + ) + api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore + + api_version = ( + api_version + or litellm.api_version + or get_secret("AZURE_API_VERSION") + ) # type: ignore + + api_key = ( + api_key + or litellm.api_key + or litellm.azure_key + or get_secret("AZURE_OPENAI_API_KEY") + or get_secret("AZURE_API_KEY") + ) # type: ignore + + azure_ad_token: Optional[str] = optional_params.get("extra_body", {}).pop( # type: ignore + "azure_ad_token", None + ) or get_secret( + "AZURE_AD_TOKEN" + ) + + headers = headers or litellm.headers + + response = azure_chat_completions.audio_speech( + model=model, + input=input, + voice=voice, + optional_params=optional_params, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + organization=organization, + max_retries=max_retries, + timeout=timeout, + client=client, # pass AsyncOpenAI, OpenAI client + aspeech=aspeech, + ) + elif ( + custom_llm_provider == "vertex_ai" + or custom_llm_provider == "vertex_ai_beta" + ): + from litellm.types.router import GenericLiteLLMParams + + generic_optional_params = GenericLiteLLMParams(**kwargs) + + api_base = generic_optional_params.api_base or "" + vertex_ai_project = ( + generic_optional_params.vertex_project + or litellm.vertex_project + or get_secret("VERTEXAI_PROJECT") + ) + vertex_ai_location = ( + generic_optional_params.vertex_location + or litellm.vertex_location + or get_secret("VERTEXAI_LOCATION") + ) + vertex_credentials = ( + generic_optional_params.vertex_credentials + or get_secret("VERTEXAI_CREDENTIALS") + ) + + if voice is not None and not isinstance(voice, dict): + raise litellm.BadRequestError( + message=f"'voice' is required to be passed as a dict for Vertex AI TTS, passed in voice={voice}", + model=model, + llm_provider=custom_llm_provider, + ) + response = vertex_text_to_speech.audio_speech( + _is_async=aspeech, + vertex_credentials=vertex_credentials, + vertex_project=vertex_ai_project, + vertex_location=vertex_ai_location, + timeout=timeout, + api_base=api_base, + model=model, + input=input, + voice=voice, + optional_params=optional_params, + kwargs=kwargs, + logging_obj=logging_obj, + ) + + if response is None: + raise Exception( + "Unable to map the custom llm provider={} to a known provider={}.".format( + custom_llm_provider, litellm.provider_list + ) + ) + return response + + +##### Health Endpoints ####################### + + +async def ahealth_check( + model_params: dict, + mode: Optional[ + Literal["completion", "embedding", "image_generation", "chat", "batch"] + ] = None, + prompt: Optional[str] = None, + input: Optional[List] = None, + default_timeout: float = 6000, +): + """ + Support health checks for different providers. Return remaining rate limit, etc. + + For azure/openai -> completion.with_raw_response + For rest -> litellm.acompletion() + """ + passed_in_mode: Optional[str] = None + try: + model: Optional[str] = model_params.get("model", None) + + if model is None: + raise Exception("model not set") + + if model in litellm.model_cost and mode is None: + mode = litellm.model_cost[model].get("mode") + + model, custom_llm_provider, _, _ = get_llm_provider(model=model) + + if model in litellm.model_cost and mode is None: + mode = litellm.model_cost[model].get("mode") + + mode = mode + passed_in_mode = mode + if mode is None: + mode = "chat" # default to chat completion calls + + if custom_llm_provider == "azure": + api_key = ( + model_params.get("api_key") + or get_secret("AZURE_API_KEY") + or get_secret("AZURE_OPENAI_API_KEY") + ) + + api_base = ( + model_params.get("api_base") + or get_secret("AZURE_API_BASE") + or get_secret("AZURE_OPENAI_API_BASE") + ) + + api_version = ( + model_params.get("api_version") + or get_secret("AZURE_API_VERSION") + or get_secret("AZURE_OPENAI_API_VERSION") + ) + + timeout = ( + model_params.get("timeout") + or litellm.request_timeout + or default_timeout + ) + + response = await azure_chat_completions.ahealth_check( + model=model, + messages=model_params.get( + "messages", None + ), # Replace with your actual messages list + api_key=api_key, + api_base=api_base, + api_version=api_version, + timeout=timeout, + mode=mode, + prompt=prompt, + input=input, + ) + elif ( + custom_llm_provider == "openai" + or custom_llm_provider == "text-completion-openai" + ): + api_key = model_params.get("api_key") or get_secret( + "OPENAI_API_KEY" + ) + organization = model_params.get("organization") + + timeout = ( + model_params.get("timeout") + or litellm.request_timeout + or default_timeout + ) + + api_base = model_params.get("api_base") or get_secret( + "OPENAI_API_BASE" + ) + + if custom_llm_provider == "text-completion-openai": + mode = "completion" + + response = await openai_chat_completions.ahealth_check( + model=model, + messages=model_params.get( + "messages", None + ), # Replace with your actual messages list + api_key=api_key, + api_base=api_base, + timeout=timeout, + mode=mode, + prompt=prompt, + input=input, + organization=organization, + ) + else: + model_params["cache"] = { + "no-cache": True + } # don't used cached responses for making health check calls + if mode == "embedding": + model_params.pop("messages", None) + model_params["input"] = input + await litellm.aembedding(**model_params) + response = {} + elif mode == "image_generation": + model_params.pop("messages", None) + model_params["prompt"] = prompt + await litellm.aimage_generation(**model_params) + response = {} + elif "*" in model: + from litellm.litellm_core_utils.llm_request_utils import ( + pick_cheapest_model_from_llm_provider, + ) + + # this is a wildcard model, we need to pick a random model from the provider + cheapest_model = pick_cheapest_model_from_llm_provider( + custom_llm_provider=custom_llm_provider + ) + model_params["model"] = cheapest_model + await acompletion(**model_params) + response = {} # args like remaining ratelimit etc. + else: # default to completion calls + await acompletion(**model_params) + response = {} # args like remaining ratelimit etc. + return response + except Exception as e: + stack_trace = traceback.format_exc() + if isinstance(stack_trace, str): + stack_trace = stack_trace[:1000] + + if passed_in_mode is None: + return { + "error": f"error:{str(e)}. Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models \nstacktrace: {stack_trace}" + } + + error_to_return = ( + str(e) + + "\nHave you set 'mode' - https://docs.litellm.ai/docs/proxy/health#embedding-models" + + "\nstack trace: " + + stack_trace + ) + return {"error": error_to_return} + + +####### HELPER FUNCTIONS ################ +## Set verbose to true -> ```litellm.set_verbose = True``` +def print_verbose(print_statement): + try: + verbose_logger.debug(print_statement) + if litellm.set_verbose: + print(print_statement) # noqa + except: + pass + + +def config_completion(**kwargs): + if litellm.config_path != None: + config_args = read_config_args(litellm.config_path) + # overwrite any args passed in with config args + return completion(**kwargs, **config_args) + else: + raise ValueError( + "No config path set, please set a config path using `litellm.config_path = 'path/to/config.json'`" + ) + + +def stream_chunk_builder_text_completion( + chunks: list, messages: Optional[List] = None +): + id = chunks[0]["id"] + object = chunks[0]["object"] + created = chunks[0]["created"] + model = chunks[0]["model"] + system_fingerprint = chunks[0].get("system_fingerprint", None) + finish_reason = chunks[-1]["choices"][0]["finish_reason"] + logprobs = chunks[-1]["choices"][0]["logprobs"] + + response = { + "id": id, + "object": object, + "created": created, + "model": model, + "system_fingerprint": system_fingerprint, + "choices": [ + { + "text": None, + "index": 0, + "logprobs": logprobs, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": None, + "completion_tokens": None, + "total_tokens": None, + }, + } + content_list = [] + for chunk in chunks: + choices = chunk["choices"] + for choice in choices: + if ( + choice is not None + and hasattr(choice, "text") + and choice.get("text") is not None + ): + _choice = choice.get("text") + content_list.append(_choice) + + # Combine the "content" strings into a single string || combine the 'function' strings into a single string + combined_content = "".join(content_list) + + # Update the "content" field within the response dictionary + response["choices"][0]["text"] = combined_content + + if len(combined_content) > 0: + completion_output = combined_content + else: + completion_output = "" + # # Update usage information if needed + try: + response["usage"]["prompt_tokens"] = token_counter( + model=model, messages=messages + ) + except: # don't allow this failing to block a complete streaming response from being returned + print_verbose(f"token_counter failed, assuming prompt tokens is 0") + response["usage"]["prompt_tokens"] = 0 + response["usage"]["completion_tokens"] = token_counter( + model=model, + text=combined_content, + count_response_tokens=True, # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages + ) + response["usage"]["total_tokens"] = ( + response["usage"]["prompt_tokens"] + + response["usage"]["completion_tokens"] + ) + return response + + +def stream_chunk_builder( + chunks: list, + messages: Optional[list] = None, + start_time=None, + end_time=None, +) -> Optional[Union[ModelResponse, TextCompletionResponse]]: + try: + model_response = litellm.ModelResponse() + ### BASE-CASE ### + if len(chunks) == 0: + return None + ### SORT CHUNKS BASED ON CREATED ORDER ## + print_verbose( + "Goes into checking if chunk has hiddden created at param" + ) + if chunks[0]._hidden_params.get("created_at", None): + print_verbose("Chunks have a created at hidden param") + # Sort chunks based on created_at in ascending order + chunks = sorted( + chunks, + key=lambda x: x._hidden_params.get("created_at", float("inf")), + ) + print_verbose("Chunks sorted") + + # set hidden params from chunk to model_response + if model_response is not None and hasattr( + model_response, "_hidden_params" + ): + model_response._hidden_params = chunks[0].get("_hidden_params", {}) + id = chunks[0]["id"] + object = chunks[0]["object"] + created = chunks[0]["created"] + model = chunks[0]["model"] + system_fingerprint = chunks[0].get("system_fingerprint", None) + + if isinstance( + chunks[0]["choices"][0], litellm.utils.TextChoices + ): # route to the text completion logic + return stream_chunk_builder_text_completion( + chunks=chunks, messages=messages + ) + role = chunks[0]["choices"][0]["delta"]["role"] + finish_reason = chunks[-1]["choices"][0]["finish_reason"] + + # Initialize the response dictionary + response = { + "id": id, + "object": object, + "created": created, + "model": model, + "system_fingerprint": system_fingerprint, + "choices": [ + { + "index": 0, + "message": {"role": role, "content": ""}, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": 0, # Modify as needed + "completion_tokens": 0, # Modify as needed + "total_tokens": 0, # Modify as needed + }, + } + + # Extract the "content" strings from the nested dictionaries within "choices" + content_list = [] + combined_content = "" + combined_arguments = "" + + tool_call_chunks = [ + chunk + for chunk in chunks + if "tool_calls" in chunk["choices"][0]["delta"] + and chunk["choices"][0]["delta"]["tool_calls"] is not None + ] + + if len(tool_call_chunks) > 0: + argument_list: List = [] + delta = tool_call_chunks[0]["choices"][0]["delta"] + message = response["choices"][0]["message"] + message["tool_calls"] = [] + id = None + name = None + type = None + tool_calls_list = [] + prev_index = None + prev_name = None + prev_id = None + curr_id = None + curr_index = 0 + for chunk in tool_call_chunks: + choices = chunk["choices"] + for choice in choices: + delta = choice.get("delta", {}) + tool_calls = delta.get("tool_calls", "") + # Check if a tool call is present + if tool_calls and tool_calls[0].function is not None: + if tool_calls[0].id: + id = tool_calls[0].id + curr_id = id + if prev_id is None: + prev_id = curr_id + if tool_calls[0].index: + curr_index = tool_calls[0].index + if tool_calls[0].function.arguments: + # Now, tool_calls is expected to be a dictionary + arguments = tool_calls[0].function.arguments + argument_list.append(arguments) + if tool_calls[0].function.name: + name = tool_calls[0].function.name + if tool_calls[0].type: + type = tool_calls[0].type + if prev_index is None: + prev_index = curr_index + if prev_name is None: + prev_name = name + if curr_index != prev_index: # new tool call + combined_arguments = "".join(argument_list) + tool_calls_list.append( + { + "id": prev_id, + "function": { + "arguments": combined_arguments, + "name": prev_name, + }, + "type": type, + } + ) + argument_list = [] # reset + prev_index = curr_index + prev_id = curr_id + prev_name = name + + combined_arguments = ( + "".join(argument_list) or "{}" + ) # base case, return empty dict + + tool_calls_list.append( + { + "id": id, + "function": { + "arguments": combined_arguments, + "name": name, + }, + "type": type, + } + ) + response["choices"][0]["message"]["content"] = None + response["choices"][0]["message"]["tool_calls"] = tool_calls_list + + function_call_chunks = [ + chunk + for chunk in chunks + if "function_call" in chunk["choices"][0]["delta"] + and chunk["choices"][0]["delta"]["function_call"] is not None + ] + + if len(function_call_chunks) > 0: + argument_list = [] + delta = function_call_chunks[0]["choices"][0]["delta"] + function_call = delta.get("function_call", "") + function_call_name = function_call.name + + message = response["choices"][0]["message"] + message["function_call"] = {} + message["function_call"]["name"] = function_call_name + + for chunk in function_call_chunks: + choices = chunk["choices"] + for choice in choices: + delta = choice.get("delta", {}) + function_call = delta.get("function_call", "") + + # Check if a function call is present + if function_call: + # Now, function_call is expected to be a dictionary + arguments = function_call.arguments + argument_list.append(arguments) + + combined_arguments = "".join(argument_list) + response["choices"][0]["message"]["content"] = None + response["choices"][0]["message"]["function_call"][ + "arguments" + ] = combined_arguments + + content_chunks = [ + chunk + for chunk in chunks + if "content" in chunk["choices"][0]["delta"] + and chunk["choices"][0]["delta"]["content"] is not None + ] + + if len(content_chunks) > 0: + for chunk in chunks: + choices = chunk["choices"] + for choice in choices: + delta = choice.get("delta", {}) + content = delta.get("content", "") + if content is None: + continue # openai v1.0.0 sets content = None for chunks + content_list.append(content) + + # Combine the "content" strings into a single string || combine the 'function' strings into a single string + combined_content = "".join(content_list) + + # Update the "content" field within the response dictionary + response["choices"][0]["message"]["content"] = combined_content + + completion_output = "" + if len(combined_content) > 0: + completion_output += combined_content + if len(combined_arguments) > 0: + completion_output += combined_arguments + + # Update usage information if needed + prompt_tokens = 0 + completion_tokens = 0 + # anthropic prompt caching information + cache_creation_input_tokens: Optional[int] = None + cache_read_input_tokens: Optional[int] = None + for chunk in chunks: + usage_chunk: Optional[Usage] = None + if "usage" in chunk: + usage_chunk = chunk.usage + elif ( + hasattr(chunk, "_hidden_params") + and "usage" in chunk._hidden_params + ): + usage_chunk = chunk._hidden_params["usage"] + if usage_chunk is not None: + if "prompt_tokens" in usage_chunk: + prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0 + if "completion_tokens" in usage_chunk: + completion_tokens = ( + usage_chunk.get("completion_tokens", 0) or 0 + ) + if "cache_creation_input_tokens" in usage_chunk: + cache_creation_input_tokens = usage_chunk.get( + "cache_creation_input_tokens" + ) + if "cache_read_input_tokens" in usage_chunk: + cache_read_input_tokens = usage_chunk.get( + "cache_read_input_tokens" + ) + + try: + response["usage"][ + "prompt_tokens" + ] = prompt_tokens or token_counter(model=model, messages=messages) + except ( + Exception + ): # don't allow this failing to block a complete streaming response from being returned + print_verbose("token_counter failed, assuming prompt tokens is 0") + response["usage"]["prompt_tokens"] = 0 + response["usage"][ + "completion_tokens" + ] = completion_tokens or token_counter( + model=model, + text=completion_output, + count_response_tokens=True, # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages + ) + response["usage"]["total_tokens"] = ( + response["usage"]["prompt_tokens"] + + response["usage"]["completion_tokens"] + ) + + if cache_creation_input_tokens is not None: + response["usage"][ + "cache_creation_input_tokens" + ] = cache_creation_input_tokens + if cache_read_input_tokens is not None: + response["usage"][ + "cache_read_input_tokens" + ] = cache_read_input_tokens + + return convert_to_model_response_object( + response_object=response, + model_response_object=model_response, + start_time=start_time, + end_time=end_time, + ) # type: ignore + except Exception as e: + verbose_logger.exception( + "litellm.main.py::stream_chunk_builder() - Exception occurred - {}".format( + str(e) + ) + ) + raise litellm.APIError( + status_code=500, + message="Error building chunks for logging/streaming usage calculation", + llm_provider="", + model="", + ) diff --git a/notdiamond/toolkit/litellm_notdiamond.py b/notdiamond/toolkit/litellm_notdiamond.py new file mode 100644 index 00000000..c44ab627 --- /dev/null +++ b/notdiamond/toolkit/litellm_notdiamond.py @@ -0,0 +1,266 @@ +# flake8: noqa + +import types +from typing import Callable, Dict, List, Optional + +import httpx +import litellm +import requests +from litellm._version import version +from litellm.utils import ModelResponse + +# dict to map notdiamond providers and models to litellm providers and models +ND2LITELLM = { + # openai + "openai/gpt-3.5-turbo": "gpt-3.5-turbo-0125", + "openai/gpt-3.5-turbo-0125": "gpt-3.5-turbo-0125", + "openai/gpt-4": "gpt-4", + "openai/gpt-4-0613": "gpt-4-0613", + "openai/gpt-4o": "gpt-4o", + "openai/gpt-4o-2024-05-13": "gpt-4o-2024-05-13", + "openai/gpt-4-turbo": "gpt-4-turbo", + "openai/gpt-4-turbo-2024-04-09": "gpt-4-turbo-2024-04-09", + "openai/gpt-4-turbo-preview": "gpt-4-turbo-preview", + "openai/gpt-4-0125-preview": "gpt-4-0125-preview", + "openai/gpt-4-1106-preview": "gpt-4-1106-preview", + "openai/gpt-4-1106-preview": "gpt-4-1106-preview", + "openai/gpt-4o-mini": "gpt-4o-mini", + "openai/gpt-4o-mini-2024-07-18": "gpt-4o-mini-2024-07-18", + # anthropic + "anthropic/claude-2.1": "claude-2.1", + "anthropic/claude-3-opus-20240229": "claude-3-opus-20240229", + "anthropic/claude-3-sonnet-20240229": "claude-3-sonnet-20240229", + "anthropic/claude-3-5-sonnet-20240620": "claude-3-5-sonnet-20240620", + "anthropic/claude-3-haiku-20240307": "claude-3-haiku-20240307", + # mistral + "mistral/mistral-large-latest": "mistral/mistral-large-latest", + "mistral/mistral-medium-latest": "mistral/mistral-medium-latest", + "mistral/mistral-small-latest": "mistral/mistral-small-latest", + "mistral/codestral-latest": "mistral/codestral-latest", + "mistral/open-mistral-7b": "mistral/open-mistral-7b", + "mistral/open-mixtral-8x7b": "mistral/open-mixtral-8x7b", + "mistral/open-mixtral-8x22b": "mistral/open-mixtral-8x22b", + "mistral/mistral-large-2407": "mistral/mistral-large-2407", + "mistral/mistral-large-2402": "mistral/mistral-large-2402", + # perplexity + "perplexity/llama-3.1-sonar-large-128k-online": "perplexity/llama-3.1-sonar-large-128k-online", + # cohere + "cohere/command-r": "cohere_chat/command-r", + "cohere/command-r-plus": "cohere_chat/command-r-plus", + # google + "google/gemini-pro": "gemini/gemini-pro", + "google/gemini-1.5-pro-latest": "gemini/gemini-1.5-pro-latest", + "google/gemini-1.5-flash-latest": "gemini/gemini-1.5-flash-latest", + "google/gemini-1.0-pro-latest": "gemini/gemini-pro", + # replicate + "replicate/mistral-7b-instruct-v0.2": "replicate/mistralai/mistral-7b-instruct-v0.2", + "replicate/mixtral-8x7b-instruct-v0.1": "replicate/mistralai/mixtral-8x7b-instruct-v0.1", + "replicate/meta-llama-3-70b-instruct": "replicate/meta/meta-llama-3-70b-instruct", + "replicate/meta-llama-3-8b-instruct": "replicate/meta/meta-llama-3-8b-instruct", + "replicate/meta-llama-3.1-405b-instruct": "replicate/meta/meta-llama-3.1-405b-instruct", + # togetherai + "togetherai/Mistral-7B-Instruct-v0.2": "together_ai/mistralai/Mistral-7B-Instruct-v0.2", + "togetherai/Mixtral-8x7B-Instruct-v0.1": "together_ai/mistralai/Mixtral-8x7B-Instruct-v0.1", + "togetherai/Mixtral-8x22B-Instruct-v0.1": "together_ai/mistralai/Mixtral-8x22B-Instruct-v0.1", + "togetherai/Llama-3-70b-chat-hf": "together_ai/meta-llama/Llama-3-70b-chat-hf", + "togetherai/Llama-3-8b-chat-hf": "together_ai/meta-llama/Llama-3-8b-chat-hf", + "togetherai/Qwen2-72B-Instruct": "together_ai/Qwen/Qwen2-72B-Instruct", + "togetherai/Meta-Llama-3.1-8B-Instruct-Turbo": "together_ai/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + "togetherai/Meta-Llama-3.1-70B-Instruct-Turbo": "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + "togetherai/Meta-Llama-3.1-405B-Instruct-Turbo": "together_ai/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo", +} + + +class NotDiamondError(Exception): + def __init__( + self, + status_code, + message, + url="https://not-diamond-server.onrender.com/v2/optimizer/modelSelect", + ): + self.status_code = status_code + self.message = message + self.request = httpx.Request(method="POST", url=url) + self.response = httpx.Response( + status_code=status_code, request=self.request + ) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + + +class NotDiamondConfig: + llm_providers: List[Dict[str, str]] + tools: Optional[List[Dict[str, str]]] = None + max_model_depth: int = 1 + # tradeoff params: "cost"/"latency" + tradeoff: Optional[str] = None + preference_id: Optional[str] = None + hash_content: Optional[bool] = False + + def __init__( + self, + llm_providers: List[Dict[str, str]], + tools: Optional[str] = None, + max_model_depth: Optional[int] = 1, + tradeoff: Optional[str] = None, + preference_id: Optional[str] = None, + hash_content: Optional[bool] = False, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + or k == "llm_providers" + } + + +def validate_environment(api_key): + if api_key is None: + raise ValueError( + "Missing NOTDIAMOND_API_KEY in env - A call is being made to Not Diamond but no key is set either in the environment variables or via params" + ) + headers = { + "Authorization": "Bearer " + api_key, + "accept": "application/json", + "content-type": "application/json", + "User-Agent": f"litellm/{version}", + } + return headers + + +def get_litellm_model(response: dict) -> str: + nd_provider = response["providers"][0]["provider"] + nd_model = response["providers"][0]["model"] + nd_provider_model = f"{nd_provider}/{nd_model}" + litellm_model = ND2LITELLM[nd_provider_model] + return litellm_model + + +def update_litellm_params(litellm_params: dict): + """ + Create a new litellm_params dict with non-default litellm_params from the original call, custom_llm_provider and api_base + """ + new_litellm_params = dict() + for k, v in litellm_params.items(): + # all litellm_params have defaults of None or False, except force_timeout + if (k == "force_timeout" and v != 600) or v: + new_litellm_params[k] = v + if "custom_llm_provider" in new_litellm_params: + del new_litellm_params["custom_llm_provider"] + if "api_base" in new_litellm_params: + del new_litellm_params["api_base"] + if "api_key" in new_litellm_params: + del new_litellm_params["api_key"] + return new_litellm_params + + +def completion( + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + api_key, + logging_obj, + optional_params=None, + litellm_params=None, + logger_fn=None, +): + headers = validate_environment(api_key) + completion_url = api_base + + ## Load Config + config = NotDiamondConfig.get_config() + for k, v in config.items(): + if k not in optional_params: + optional_params[k] = v + + # separate ND optional params from litellm optional params + nd_params = [ + "llm_providers", + "tools", + "max_model_depth", + "tradeoff", + "preference_id", + "hash_content", + ] + selected_model_params = dict() + for k, v in optional_params.items(): + if k not in nd_params: + selected_model_params[k] = v + if "tools" in optional_params: + selected_model_params["tools"] = optional_params["tools"] + # remove any optional params that are not in the ND params + optional_params = { + k: v for k, v in optional_params.items() if k in nd_params + } + + data = { + "messages": messages, + **optional_params, + } + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "headers": headers, + "api_base": completion_url, + }, + ) + + ## MODEL SELECTION CALL + nd_response = requests.post( + api_base, + headers=headers, + json=data, + ) + print_verbose(f"Raw response from Not Diamond: {nd_response.text}") + + ## RESPONSE OBJECT + if nd_response.status_code != 200: + raise NotDiamondError( + status_code=nd_response.status_code, message=nd_response.text + ) + nd_response = nd_response.json() + litellm_model = get_litellm_model(nd_response) + + ## COMPLETION CALL + litellm_params = update_litellm_params(litellm_params) + + is_async_call = litellm_params.pop("acompletion", False) + if is_async_call: + return litellm.acompletion( + model=litellm_model, + messages=messages, + **selected_model_params, + **litellm_params, + ) + else: + return litellm.completion( + model=litellm_model, + messages=messages, + **selected_model_params, + **litellm_params, + ) diff --git a/poetry.lock b/poetry.lock index ffb9deb9..ae106ec6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2195,13 +2195,13 @@ requests = ">=2,<3" [[package]] name = "litellm" -version = "1.44.15" +version = "1.44.26" description = "Library to easily interface with LLM API providers" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ - {file = "litellm-1.44.15-py3-none-any.whl", hash = "sha256:6b818bb9b974e72dd731fedb0ef0c2164bc3239fc879428b054f4192ad32115a"}, - {file = "litellm-1.44.15.tar.gz", hash = "sha256:7bd3a9bde01f7a80f1bbf748ccd171debcea435ba4ac0853688049f66673a44e"}, + {file = "litellm-1.44.26-py3-none-any.whl", hash = "sha256:de63115a19e1432a44e38b9a2d8d8dfb77d5745370ef80d45327017b7341dea5"}, + {file = "litellm-1.44.26.tar.gz", hash = "sha256:10856ad8e9b5fed96f2f4ad62dd3266240c2d996ef20ee9e233a2140e9c1eb5b"}, ] [package.dependencies] @@ -4051,10 +4051,54 @@ description = "Database Abstraction Library" optional = true python-versions = ">=3.7" files = [ + {file = "SQLAlchemy-2.0.34-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d0b2cf8791ab5fb9e3aa3d9a79a0d5d51f55b6357eecf532a120ba3b5524db"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:243f92596f4fd4c8bd30ab8e8dd5965afe226363d75cab2468f2c707f64cd83b"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ea54f7300553af0a2a7235e9b85f4204e1fc21848f917a3213b0e0818de9a24"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:173f5f122d2e1bff8fbd9f7811b7942bead1f5e9f371cdf9e670b327e6703ebd"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:196958cde924a00488e3e83ff917be3b73cd4ed8352bbc0f2989333176d1c54d"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:bd90c221ed4e60ac9d476db967f436cfcecbd4ef744537c0f2d5291439848768"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-win32.whl", hash = "sha256:3166dfff2d16fe9be3241ee60ece6fcb01cf8e74dd7c5e0b64f8e19fab44911b"}, + {file = "SQLAlchemy-2.0.34-cp310-cp310-win_amd64.whl", hash = "sha256:6831a78bbd3c40f909b3e5233f87341f12d0b34a58f14115c9e94b4cdaf726d3"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c7db3db284a0edaebe87f8f6642c2b2c27ed85c3e70064b84d1c9e4ec06d5d84"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:430093fce0efc7941d911d34f75a70084f12f6ca5c15d19595c18753edb7c33b"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79cb400c360c7c210097b147c16a9e4c14688a6402445ac848f296ade6283bbc"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb1b30f31a36c7f3fee848391ff77eebdd3af5750bf95fbf9b8b5323edfdb4ec"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8fddde2368e777ea2a4891a3fb4341e910a056be0bb15303bf1b92f073b80c02"}, {file = "SQLAlchemy-2.0.34-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:80bd73ea335203b125cf1d8e50fef06be709619eb6ab9e7b891ea34b5baa2287"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-win32.whl", hash = "sha256:6daeb8382d0df526372abd9cb795c992e18eed25ef2c43afe518c73f8cccb721"}, + {file = "SQLAlchemy-2.0.34-cp311-cp311-win_amd64.whl", hash = "sha256:5bc08e75ed11693ecb648b7a0a4ed80da6d10845e44be0c98c03f2f880b68ff4"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:53e68b091492c8ed2bd0141e00ad3089bcc6bf0e6ec4142ad6505b4afe64163e"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bcd18441a49499bf5528deaa9dee1f5c01ca491fc2791b13604e8f972877f812"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:165bbe0b376541092bf49542bd9827b048357f4623486096fc9aaa6d4e7c59a2"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3330415cd387d2b88600e8e26b510d0370db9b7eaf984354a43e19c40df2e2b"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:97b850f73f8abbffb66ccbab6e55a195a0eb655e5dc74624d15cff4bfb35bd74"}, {file = "SQLAlchemy-2.0.34-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7cee4c6917857fd6121ed84f56d1dc78eb1d0e87f845ab5a568aba73e78adf83"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-win32.whl", hash = "sha256:fbb034f565ecbe6c530dff948239377ba859420d146d5f62f0271407ffb8c580"}, + {file = "SQLAlchemy-2.0.34-cp312-cp312-win_amd64.whl", hash = "sha256:707c8f44931a4facd4149b52b75b80544a8d824162602b8cd2fe788207307f9a"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24af3dc43568f3780b7e1e57c49b41d98b2d940c1fd2e62d65d3928b6f95f021"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e60ed6ef0a35c6b76b7640fe452d0e47acc832ccbb8475de549a5cc5f90c2c06"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:413c85cd0177c23e32dee6898c67a5f49296640041d98fddb2c40888fe4daa2e"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:25691f4adfb9d5e796fd48bf1432272f95f4bbe5f89c475a788f31232ea6afba"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:526ce723265643dbc4c7efb54f56648cc30e7abe20f387d763364b3ce7506c82"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-win32.whl", hash = "sha256:13be2cc683b76977a700948411a94c67ad8faf542fa7da2a4b167f2244781cf3"}, + {file = "SQLAlchemy-2.0.34-cp37-cp37m-win_amd64.whl", hash = "sha256:e54ef33ea80d464c3dcfe881eb00ad5921b60f8115ea1a30d781653edc2fd6a2"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:43f28005141165edd11fbbf1541c920bd29e167b8bbc1fb410d4fe2269c1667a"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b68094b165a9e930aedef90725a8fcfafe9ef95370cbb54abc0464062dbf808f"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a1e03db964e9d32f112bae36f0cc1dcd1988d096cfd75d6a588a3c3def9ab2b"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:203d46bddeaa7982f9c3cc693e5bc93db476ab5de9d4b4640d5c99ff219bee8c"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ae92bebca3b1e6bd203494e5ef919a60fb6dfe4d9a47ed2453211d3bd451b9f5"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:9661268415f450c95f72f0ac1217cc6f10256f860eed85c2ae32e75b60278ad8"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-win32.whl", hash = "sha256:895184dfef8708e15f7516bd930bda7e50ead069280d2ce09ba11781b630a434"}, + {file = "SQLAlchemy-2.0.34-cp38-cp38-win_amd64.whl", hash = "sha256:6e7cde3a2221aa89247944cafb1b26616380e30c63e37ed19ff0bba5e968688d"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:dbcdf987f3aceef9763b6d7b1fd3e4ee210ddd26cac421d78b3c206d07b2700b"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ce119fc4ce0d64124d37f66a6f2a584fddc3c5001755f8a49f1ca0a177ef9796"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a17d8fac6df9835d8e2b4c5523666e7051d0897a93756518a1fe101c7f47f2f0"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ebc11c54c6ecdd07bb4efbfa1554538982f5432dfb8456958b6d46b9f834bb7"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2e6965346fc1491a566e019a4a1d3dfc081ce7ac1a736536367ca305da6472a8"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:220574e78ad986aea8e81ac68821e47ea9202b7e44f251b7ed8c66d9ae3f4278"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-win32.whl", hash = "sha256:b75b00083e7fe6621ce13cfce9d4469c4774e55e8e9d38c305b37f13cf1e874c"}, + {file = "SQLAlchemy-2.0.34-cp39-cp39-win_amd64.whl", hash = "sha256:c29d03e0adf3cc1a8c3ec62d176824972ae29b67a66cbb18daff3062acc6faa8"}, + {file = "SQLAlchemy-2.0.34-py3-none-any.whl", hash = "sha256:7286c353ee6475613d8beff83167374006c6b3e3f0e6491bfe8ca610eb1dec0f"}, {file = "sqlalchemy-2.0.34.tar.gz", hash = "sha256:10d8f36990dd929690666679b0f42235c159a7051534adb135728ee52828dd22"}, ] diff --git a/tests/test_toolkit/test_litellm.py b/tests/test_toolkit/test_litellm.py new file mode 100644 index 00000000..7c73c50e --- /dev/null +++ b/tests/test_toolkit/test_litellm.py @@ -0,0 +1,216 @@ +import os + +import pytest + +from notdiamond.settings import GOOGLE_API_KEY, PPLX_API_KEY, TOGETHER_API_KEY +from notdiamond.toolkit.litellm import acompletion, completion + +os.environ["TOGETHERAI_API_KEY"] = TOGETHER_API_KEY +os.environ["GEMINI_API_KEY"] = GOOGLE_API_KEY +os.environ["PERPLEXITYAI_API_KEY"] = PPLX_API_KEY + + +# nd providers and models +ND_MODEL_LIST = [ + {"provider": "openai", "model": "gpt-3.5-turbo"}, + {"provider": "openai", "model": "gpt-3.5-turbo-0125"}, + {"provider": "openai", "model": "gpt-4"}, + {"provider": "openai", "model": "gpt-4-0613"}, + {"provider": "openai", "model": "gpt-4o"}, + {"provider": "openai", "model": "gpt-4o-2024-05-13"}, + {"provider": "openai", "model": "gpt-4-turbo"}, + {"provider": "openai", "model": "gpt-4-turbo-2024-04-09"}, + {"provider": "openai", "model": "gpt-4-turbo-preview"}, + {"provider": "openai", "model": "gpt-4-0125-preview"}, + {"provider": "openai", "model": "gpt-4-1106-preview"}, + {"provider": "openai", "model": "gpt-4o-mini"}, + {"provider": "openai", "model": "gpt-4o-mini-2024-07-18"}, + {"provider": "anthropic", "model": "claude-2.1"}, + {"provider": "anthropic", "model": "claude-3-opus-20240229"}, + {"provider": "anthropic", "model": "claude-3-sonnet-20240229"}, + {"provider": "anthropic", "model": "claude-3-5-sonnet-20240620"}, + {"provider": "anthropic", "model": "claude-3-haiku-20240307"}, + {"provider": "mistral", "model": "mistral-large-latest"}, + {"provider": "mistral", "model": "mistral-medium-latest"}, + {"provider": "mistral", "model": "mistral-small-latest"}, + {"provider": "mistral", "model": "codestral-latest"}, + {"provider": "mistral", "model": "open-mistral-7b"}, + {"provider": "mistral", "model": "open-mixtral-8x7b"}, + {"provider": "mistral", "model": "open-mixtral-8x22b"}, + {"provider": "mistral", "model": "mistral-large-2407"}, + {"provider": "mistral", "model": "mistral-large-2402"}, + {"provider": "perplexity", "model": "llama-3.1-sonar-large-128k-online"}, + {"provider": "cohere", "model": "command-r"}, + {"provider": "cohere", "model": "command-r-plus"}, + {"provider": "google", "model": "gemini-pro"}, + {"provider": "google", "model": "gemini-1.5-pro-latest"}, + {"provider": "google", "model": "gemini-1.5-flash-latest"}, + {"provider": "google", "model": "gemini-1.0-pro-latest"}, + # {"provider": "replicate", "model": "mistral-7b-instruct-v0.2"}, removed due to replicate side error + {"provider": "replicate", "model": "mixtral-8x7b-instruct-v0.1"}, + {"provider": "replicate", "model": "meta-llama-3-70b-instruct"}, + {"provider": "replicate", "model": "meta-llama-3-8b-instruct"}, + {"provider": "replicate", "model": "meta-llama-3.1-405b-instruct"}, + {"provider": "togetherai", "model": "Mistral-7B-Instruct-v0.2"}, + {"provider": "togetherai", "model": "Mixtral-8x7B-Instruct-v0.1"}, + {"provider": "togetherai", "model": "Mixtral-8x22B-Instruct-v0.1"}, + {"provider": "togetherai", "model": "Llama-3-70b-chat-hf"}, + {"provider": "togetherai", "model": "Llama-3-8b-chat-hf"}, + {"provider": "togetherai", "model": "Qwen2-72B-Instruct"}, + {"provider": "togetherai", "model": "Meta-Llama-3.1-8B-Instruct-Turbo"}, + {"provider": "togetherai", "model": "Meta-Llama-3.1-70B-Instruct-Turbo"}, + {"provider": "togetherai", "model": "Meta-Llama-3.1-405B-Instruct-Turbo"}, +] + +ND_TOOLS_MODEL_LIST = [ + {"provider": "openai", "model": "gpt-3.5-turbo"}, + {"provider": "openai", "model": "gpt-3.5-turbo-0125"}, + {"provider": "openai", "model": "gpt-4"}, + {"provider": "openai", "model": "gpt-4-0613"}, + {"provider": "openai", "model": "gpt-4o"}, + {"provider": "openai", "model": "gpt-4o-2024-05-13"}, + {"provider": "openai", "model": "gpt-4-turbo"}, + {"provider": "openai", "model": "gpt-4-turbo-2024-04-09"}, + {"provider": "openai", "model": "gpt-4-turbo-preview"}, + {"provider": "openai", "model": "gpt-4-0125-preview"}, + {"provider": "openai", "model": "gpt-4-1106-preview"}, + {"provider": "openai", "model": "gpt-4o-mini"}, + {"provider": "openai", "model": "gpt-4o-mini-2024-07-18"}, + {"provider": "anthropic", "model": "claude-3-opus-20240229"}, + {"provider": "anthropic", "model": "claude-3-sonnet-20240229"}, + {"provider": "anthropic", "model": "claude-3-5-sonnet-20240620"}, + {"provider": "anthropic", "model": "claude-3-haiku-20240307"}, + {"provider": "mistral", "model": "mistral-large-latest"}, + {"provider": "mistral", "model": "mistral-small-latest"}, + {"provider": "cohere", "model": "command-r"}, + {"provider": "cohere", "model": "command-r-plus"}, + {"provider": "google", "model": "gemini-pro"}, + {"provider": "google", "model": "gemini-1.5-pro-latest"}, + {"provider": "google", "model": "gemini-1.5-flash-latest"}, + {"provider": "google", "model": "gemini-1.0-pro-latest"}, +] + + +def test_completion_notdiamond(): + try: + messages = [ + { + "role": "user", + "content": "Hey", + }, + ] + for model in ND_MODEL_LIST: + print(f"Testing {model}") + _ = completion( + model="notdiamond/notdiamond", + messages=messages, + llm_providers=[model], + ) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + +def test_completion_notdiamond_stream(): + try: + messages = [ + { + "role": "user", + "content": "Hey", + }, + ] + for model in ND_MODEL_LIST: + _ = completion( + model="notdiamond/notdiamond", + messages=messages, + llm_providers=[model], + stream=True, + ) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + +def test_completion_notdiamond_tool_calling(): + try: + messages = [ + { + "role": "user", + "content": "what is 2 + 5?", + }, + ] + tools = [ + { + "type": "function", + "function": { + "name": "add", + "description": "Adds a and b.", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"}, + }, + "required": ["a", "b"], + }, + }, + }, + ] + for model in ND_TOOLS_MODEL_LIST: + _ = completion( + model="notdiamond/notdiamond", + messages=messages, + llm_providers=[model], + tools=tools, + ) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + +def test_async_completion_notdiamond(): + import asyncio + + async def test_get_response(model): + user_message = "Hello, how are you?" + messages = [{"content": user_message, "role": "user"}] + try: + _ = await acompletion( + model="notdiamond/notdiamond", + messages=messages, + llm_providers=[model], + num_retries=3, + timeout=10, + ) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + async def run_concurrent_tests(): + _ = await asyncio.gather( + *[test_get_response(model) for model in ND_MODEL_LIST] + ) + + asyncio.run(run_concurrent_tests()) + + +def test_async_completion_notdiamond_stream(): + import asyncio + + async def test_get_response(model): + user_message = "Hello, how are you?" + messages = [{"content": user_message, "role": "user"}] + try: + _ = await acompletion( + model="notdiamond/notdiamond", + messages=messages, + llm_providers=[model], + num_retries=3, + timeout=10, + stream=True, + ) + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + async def run_concurrent_tests(): + _ = await asyncio.gather( + *[test_get_response(model) for model in ND_MODEL_LIST] + ) + + asyncio.run(run_concurrent_tests()) From 781cf538ac81ccc8ab7a50c272b05dca4606bd8f Mon Sep 17 00:00:00 2001 From: Tze-Yang Tung Date: Fri, 13 Sep 2024 10:11:54 -0400 Subject: [PATCH 2/7] cleanup --- notdiamond/toolkit/__init__.py | 5 ++--- notdiamond/toolkit/litellm/__init__.py | 1 + notdiamond/toolkit/{ => litellm}/litellm.py | 0 notdiamond/toolkit/{ => litellm}/litellm_notdiamond.py | 0 4 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 notdiamond/toolkit/litellm/__init__.py rename notdiamond/toolkit/{ => litellm}/litellm.py (100%) rename notdiamond/toolkit/{ => litellm}/litellm_notdiamond.py (100%) diff --git a/notdiamond/toolkit/__init__.py b/notdiamond/toolkit/__init__.py index 13eee816..525ff9eb 100644 --- a/notdiamond/toolkit/__init__.py +++ b/notdiamond/toolkit/__init__.py @@ -1,3 +1,2 @@ -from .custom_router import CustomRouter - -CustomRouter +from .custom_router import CustomRouter # noqa +from .litellm import litellm # noqa diff --git a/notdiamond/toolkit/litellm/__init__.py b/notdiamond/toolkit/litellm/__init__.py new file mode 100644 index 00000000..e5523de3 --- /dev/null +++ b/notdiamond/toolkit/litellm/__init__.py @@ -0,0 +1 @@ +from .litellm import * # noqa diff --git a/notdiamond/toolkit/litellm.py b/notdiamond/toolkit/litellm/litellm.py similarity index 100% rename from notdiamond/toolkit/litellm.py rename to notdiamond/toolkit/litellm/litellm.py diff --git a/notdiamond/toolkit/litellm_notdiamond.py b/notdiamond/toolkit/litellm/litellm_notdiamond.py similarity index 100% rename from notdiamond/toolkit/litellm_notdiamond.py rename to notdiamond/toolkit/litellm/litellm_notdiamond.py From c551dced40ce456ed4e2bedc4309a596c5f8612d Mon Sep 17 00:00:00 2001 From: Tze-Yang Tung Date: Fri, 13 Sep 2024 10:51:58 -0400 Subject: [PATCH 3/7] even more brute force --- notdiamond/toolkit/litellm/__init__.py | 1047 ++++++++++++++++- .../toolkit/litellm/{litellm.py => main.py} | 20 +- poetry.lock | 14 +- tests/test_toolkit/test_litellm.py | 6 +- 4 files changed, 1063 insertions(+), 24 deletions(-) rename notdiamond/toolkit/litellm/{litellm.py => main.py} (99%) diff --git a/notdiamond/toolkit/litellm/__init__.py b/notdiamond/toolkit/litellm/__init__.py index e5523de3..1140ee27 100644 --- a/notdiamond/toolkit/litellm/__init__.py +++ b/notdiamond/toolkit/litellm/__init__.py @@ -1 +1,1046 @@ -from .litellm import * # noqa +# flake8: noqa +# +### Hide pydantic namespace conflict warnings globally ### +import warnings + +from .litellm_notdiamond import NotDiamondConfig + +warnings.filterwarnings( + "ignore", message=".*conflict with protected namespace.*" +) +import os + +### INIT VARIABLES ### +import threading +from enum import Enum +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Union, + get_args, +) + +import dotenv +import httpx +import requests +from litellm._logging import ( + _turn_on_debug, + _turn_on_json, + json_logs, + log_level, + set_verbose, + verbose_logger, +) +from litellm.caching import Cache +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, +) +from litellm.proxy._types import ( + KeyManagementSettings, + KeyManagementSystem, + LiteLLM_UpperboundKeyGenerateParams, +) +from litellm.types.guardrails import GuardrailItem + +litellm_mode = os.getenv("LITELLM_MODE", "DEV") # "PRODUCTION", "DEV" +if litellm_mode == "DEV": + dotenv.load_dotenv() +############################################# +if set_verbose == True: + _turn_on_debug() +############################################# +### Callbacks /Logging / Success / Failure Handlers ### +input_callback: List[Union[str, Callable]] = [] +success_callback: List[Union[str, Callable]] = [] +failure_callback: List[Union[str, Callable]] = [] +service_callback: List[Union[str, Callable]] = [] +_custom_logger_compatible_callbacks_literal = Literal[ + "lago", + "openmeter", + "logfire", + "dynamic_rate_limiter", + "langsmith", + "prometheus", + "galileo", + "braintrust", + "arize", + "gcs_bucket", +] +_known_custom_logger_compatible_callbacks: List = list( + get_args(_custom_logger_compatible_callbacks_literal) +) +callbacks: List[ + Union[Callable, _custom_logger_compatible_callbacks_literal] +] = [] +langfuse_default_tags: Optional[List[str]] = None +langsmith_batch_size: Optional[int] = None +_async_input_callback: List[ + Callable +] = [] # internal variable - async custom callbacks are routed here. +_async_success_callback: List[ + Union[str, Callable] +] = [] # internal variable - async custom callbacks are routed here. +_async_failure_callback: List[ + Callable +] = [] # internal variable - async custom callbacks are routed here. +pre_call_rules: List[Callable] = [] +post_call_rules: List[Callable] = [] +turn_off_message_logging: Optional[bool] = False +log_raw_request_response: bool = False +redact_messages_in_exceptions: Optional[bool] = False +redact_user_api_key_info: Optional[bool] = False +store_audit_logs = False # Enterprise feature, allow users to see audit logs +## end of callbacks ############# + +email: Optional[ + str +] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +token: Optional[ + str +] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +telemetry = True +max_tokens = 256 # OpenAI Defaults +drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False)) +modify_params = False +retry = True +### AUTH ### +api_key: Optional[str] = None +openai_key: Optional[str] = None +databricks_key: Optional[str] = None +azure_key: Optional[str] = None +anthropic_key: Optional[str] = None +replicate_key: Optional[str] = None +cohere_key: Optional[str] = None +clarifai_key: Optional[str] = None +maritalk_key: Optional[str] = None +ai21_key: Optional[str] = None +ollama_key: Optional[str] = None +openrouter_key: Optional[str] = None +predibase_key: Optional[str] = None +huggingface_key: Optional[str] = None +vertex_project: Optional[str] = None +vertex_location: Optional[str] = None +predibase_tenant_id: Optional[str] = None +togetherai_api_key: Optional[str] = None +cloudflare_api_key: Optional[str] = None +baseten_key: Optional[str] = None +notdiamond_key: Optional[str] = None +aleph_alpha_key: Optional[str] = None +nlp_cloud_key: Optional[str] = None +common_cloud_provider_auth_params: dict = { + "params": ["project", "region_name", "token"], + "providers": [ + "vertex_ai", + "bedrock", + "watsonx", + "azure", + "vertex_ai_beta", + ], +} +use_client: bool = False +ssl_verify: Union[str, bool] = True +ssl_certificate: Optional[str] = None +disable_streaming_logging: bool = False +in_memory_llm_clients_cache: dict = {} +safe_memory_mode: bool = False +enable_azure_ad_token_refresh: Optional[bool] = False +### DEFAULT AZURE API VERSION ### +AZURE_DEFAULT_API_VERSION = ( + "2024-08-01-preview" # this is updated to the latest +) +### COHERE EMBEDDINGS DEFAULT TYPE ### +COHERE_DEFAULT_EMBEDDING_INPUT_TYPE = "search_document" +### GUARDRAILS ### +llamaguard_model_name: Optional[str] = None +openai_moderations_model_name: Optional[str] = None +presidio_ad_hoc_recognizers: Optional[str] = None +google_moderation_confidence_threshold: Optional[float] = None +llamaguard_unsafe_content_categories: Optional[str] = None +blocked_user_list: Optional[Union[str, List]] = None +banned_keywords_list: Optional[Union[str, List]] = None +llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all" +guardrail_name_config_map: Dict[str, GuardrailItem] = {} +################## +### PREVIEW FEATURES ### +enable_preview_features: bool = False +return_response_headers: bool = False # get response headers from LLM Api providers - example x-remaining-requests, +enable_json_schema_validation: bool = False +################## +logging: bool = True +enable_loadbalancing_on_batch_endpoints: Optional[bool] = None +enable_caching_on_provider_specific_optional_params: bool = ( + False # feature-flag for caching on optional params - e.g. 'top_k' +) +caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +always_read_redis: bool = ( + True # always use redis for rate limiting logic on litellm proxy +) +caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 +cache: Optional[ + Cache +] = None # cache object <- use this - https://docs.litellm.ai/docs/caching +default_in_memory_ttl: Optional[float] = None +default_redis_ttl: Optional[float] = None +model_alias_map: Dict[str, str] = {} +model_group_alias_map: Dict[str, str] = {} +max_budget: float = 0.0 # set the max budget across all providers +budget_duration: Optional[ + str +] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). +default_soft_budget: float = ( + 50.0 # by default all litellm proxy keys have a soft budget of 50.0 +) +forward_traceparent_to_llm_provider: bool = False +_openai_finish_reasons = [ + "stop", + "length", + "function_call", + "content_filter", + "null", +] +_openai_completion_params = [ + "functions", + "function_call", + "temperature", + "temperature", + "top_p", + "n", + "stream", + "stop", + "max_tokens", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "request_timeout", + "api_base", + "api_version", + "api_key", + "deployment_id", + "organization", + "base_url", + "default_headers", + "timeout", + "response_format", + "seed", + "tools", + "tool_choice", + "max_retries", +] +_litellm_completion_params = [ + "metadata", + "acompletion", + "caching", + "mock_response", + "api_key", + "api_version", + "api_base", + "force_timeout", + "logger_fn", + "verbose", + "custom_llm_provider", + "litellm_logging_obj", + "litellm_call_id", + "use_client", + "id", + "fallbacks", + "azure", + "headers", + "model_list", + "num_retries", + "context_window_fallback_dict", + "roles", + "final_prompt_value", + "bos_token", + "eos_token", + "request_timeout", + "complete_response", + "self", + "client", + "rpm", + "tpm", + "input_cost_per_token", + "output_cost_per_token", + "hf_model_name", + "model_info", + "proxy_server_request", + "preset_cache_key", +] +_current_cost = 0 # private variable, used if max budget is set +error_logs: Dict = {} +add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt +client_session: Optional[httpx.Client] = None +aclient_session: Optional[httpx.AsyncClient] = None +model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks' +model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" +suppress_debug_info = False +dynamodb_table_name: Optional[str] = None +s3_callback_params: Optional[Dict] = None +generic_logger_headers: Optional[Dict] = None +default_key_generate_params: Optional[Dict] = None +upperbound_key_generate_params: Optional[ + LiteLLM_UpperboundKeyGenerateParams +] = None +default_user_params: Optional[Dict] = None +default_team_settings: Optional[List] = None +max_user_budget: Optional[float] = None +default_max_internal_user_budget: Optional[float] = None +max_internal_user_budget: Optional[float] = None +internal_user_budget_duration: Optional[str] = None +max_end_user_budget: Optional[float] = None +#### REQUEST PRIORITIZATION #### +priority_reservation: Optional[Dict[str, float]] = None +#### RELIABILITY #### +REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives. +request_timeout: float = 6000 +module_level_aclient = AsyncHTTPHandler(timeout=request_timeout) +module_level_client = HTTPHandler(timeout=request_timeout) +num_retries: Optional[int] = None # per model endpoint +default_fallbacks: Optional[List] = None +fallbacks: Optional[List] = None +context_window_fallbacks: Optional[List] = None +content_policy_fallbacks: Optional[List] = None +allowed_fails: int = 3 +num_retries_per_request: Optional[ + int +] = None # for the request overall (incl. fallbacks + model retries) +####### SECRET MANAGERS ##################### +secret_manager_client: Optional[ + Any +] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. +_google_kms_resource_name: Optional[str] = None +_key_management_system: Optional[KeyManagementSystem] = None +_key_management_settings: Optional[KeyManagementSettings] = None +#### PII MASKING #### +output_parse_pii: bool = False +############################################# + + +def get_model_cost_map(url: str): + if ( + os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == True + or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True" + ): + import importlib.resources + import json + + with importlib.resources.open_text( + "litellm", "model_prices_and_context_window_backup.json" + ) as f: + content = json.load(f) + return content + + try: + with requests.get( + url, timeout=5 + ) as response: # set a 5 second timeout for the get request + response.raise_for_status() # Raise an exception if the request is unsuccessful + content = response.json() + return content + except Exception as e: + import importlib.resources + import json + + with importlib.resources.open_text( + "litellm", "model_prices_and_context_window_backup.json" + ) as f: + content = json.load(f) + return content + + +model_cost = get_model_cost_map(url=model_cost_map_url) +custom_prompt_dict: Dict[str, dict] = {} + + +####### THREAD-SPECIFIC DATA ################### +class MyLocal(threading.local): + def __init__(self): + self.user = "Hello World" + + +_thread_context = MyLocal() + + +def identify(event_details): + # Store user in thread local data + if "user" in event_details: + _thread_context.user = event_details["user"] + + +####### ADDITIONAL PARAMS ################### configurable params if you use proxy models like Helicone, map spend to org id, etc. +api_base = None +headers = None +api_version = None +organization = None +project = None +config_path = None +vertex_ai_safety_settings: Optional[dict] = None +####### COMPLETION MODELS ################### +open_ai_chat_completion_models: List = [] +open_ai_text_completion_models: List = [] +cohere_models: List = [] +cohere_chat_models: List = [] +mistral_chat_models: List = [] +anthropic_models: List = [] +empower_models: List = [] +openrouter_models: List = [] +vertex_language_models: List = [] +vertex_vision_models: List = [] +vertex_chat_models: List = [] +vertex_code_chat_models: List = [] +vertex_ai_image_models: List = [] +vertex_text_models: List = [] +vertex_code_text_models: List = [] +vertex_embedding_models: List = [] +vertex_anthropic_models: List = [] +vertex_llama3_models: List = [] +vertex_ai_ai21_models: List = [] +vertex_mistral_models: List = [] +ai21_models: List = [] +ai21_chat_models: List = [] +nlp_cloud_models: List = [] +aleph_alpha_models: List = [] +bedrock_models: List = [] +fireworks_ai_models: List = [] +deepinfra_models: List = [] +perplexity_models: List = [] +watsonx_models: List = [] +gemini_models: List = [] +for key, value in model_cost.items(): + if value.get("litellm_provider") == "openai": + open_ai_chat_completion_models.append(key) + elif value.get("litellm_provider") == "text-completion-openai": + open_ai_text_completion_models.append(key) + elif value.get("litellm_provider") == "cohere": + cohere_models.append(key) + elif value.get("litellm_provider") == "cohere_chat": + cohere_chat_models.append(key) + elif value.get("litellm_provider") == "mistral": + mistral_chat_models.append(key) + elif value.get("litellm_provider") == "anthropic": + anthropic_models.append(key) + elif value.get("litellm_provider") == "empower": + empower_models.append(key) + elif value.get("litellm_provider") == "openrouter": + openrouter_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-text-models": + vertex_text_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-code-text-models": + vertex_code_text_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-language-models": + vertex_language_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-vision-models": + vertex_vision_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-chat-models": + vertex_chat_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-code-chat-models": + vertex_code_chat_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-embedding-models": + vertex_embedding_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-anthropic_models": + key = key.replace("vertex_ai/", "") + vertex_anthropic_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-llama_models": + key = key.replace("vertex_ai/", "") + vertex_llama3_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-mistral_models": + key = key.replace("vertex_ai/", "") + vertex_mistral_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-ai21_models": + key = key.replace("vertex_ai/", "") + vertex_ai_ai21_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-image-models": + key = key.replace("vertex_ai/", "") + vertex_ai_image_models.append(key) + elif value.get("litellm_provider") == "ai21": + if value.get("mode") == "chat": + ai21_chat_models.append(key) + else: + ai21_models.append(key) + elif value.get("litellm_provider") == "nlp_cloud": + nlp_cloud_models.append(key) + elif value.get("litellm_provider") == "aleph_alpha": + aleph_alpha_models.append(key) + elif value.get("litellm_provider") == "bedrock": + bedrock_models.append(key) + elif value.get("litellm_provider") == "deepinfra": + deepinfra_models.append(key) + elif value.get("litellm_provider") == "perplexity": + perplexity_models.append(key) + elif value.get("litellm_provider") == "watsonx": + watsonx_models.append(key) + elif value.get("litellm_provider") == "gemini": + gemini_models.append(key) + elif value.get("litellm_provider") == "fireworks_ai": + fireworks_ai_models.append(key) +# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary +openai_compatible_endpoints: List = [ + "api.perplexity.ai", + "api.endpoints.anyscale.com/v1", + "api.deepinfra.com/v1/openai", + "api.mistral.ai/v1", + "codestral.mistral.ai/v1/chat/completions", + "codestral.mistral.ai/v1/fim/completions", + "api.groq.com/openai/v1", + "https://integrate.api.nvidia.com/v1", + "api.deepseek.com/v1", + "api.together.xyz/v1", + "app.empower.dev/api/v1", + "inference.friendli.ai/v1", +] + +# this is maintained for Exception Mapping +openai_compatible_providers: List = [ + "anyscale", + "mistral", + "groq", + "nvidia_nim", + "cerebras", + "ai21_chat", + "volcengine", + "codestral", + "deepseek", + "deepinfra", + "perplexity", + "xinference", + "together_ai", + "fireworks_ai", + "empower", + "friendliai", + "azure_ai", + "github", +] +openai_text_completion_compatible_providers: List = ( + [ # providers that support `/v1/completions` + "together_ai", + "fireworks_ai", + ] +) + +# well supported replicate llms +replicate_models: List = [ + # llama replicate supported LLMs + "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf", + "a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52", + "meta/codellama-13b:1c914d844307b0588599b8393480a3ba917b660c7e9dfae681542b5325f228db", + # Vicuna + "replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b", + "joehoover/instructblip-vicuna13b:c4c54e3c8c97cd50c2d2fec9be3b6065563ccf7d43787fb99f84151b867178fe", + # Flan T-5 + "daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f", + # Others + "replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5", + "replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad", +] + +clarifai_models: List = [ + "clarifai/meta.Llama-3.Llama-3-8B-Instruct", + "clarifai/gcp.generate.gemma-1_1-7b-it", + "clarifai/mistralai.completion.mixtral-8x22B", + "clarifai/cohere.generate.command-r-plus", + "clarifai/databricks.drbx.dbrx-instruct", + "clarifai/mistralai.completion.mistral-large", + "clarifai/mistralai.completion.mistral-medium", + "clarifai/mistralai.completion.mistral-small", + "clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1", + "clarifai/gcp.generate.gemma-2b-it", + "clarifai/gcp.generate.gemma-7b-it", + "clarifai/deci.decilm.deciLM-7B-instruct", + "clarifai/mistralai.completion.mistral-7B-Instruct", + "clarifai/gcp.generate.gemini-pro", + "clarifai/anthropic.completion.claude-v1", + "clarifai/anthropic.completion.claude-instant-1_2", + "clarifai/anthropic.completion.claude-instant", + "clarifai/anthropic.completion.claude-v2", + "clarifai/anthropic.completion.claude-2_1", + "clarifai/meta.Llama-2.codeLlama-70b-Python", + "clarifai/meta.Llama-2.codeLlama-70b-Instruct", + "clarifai/openai.completion.gpt-3_5-turbo-instruct", + "clarifai/meta.Llama-2.llama2-7b-chat", + "clarifai/meta.Llama-2.llama2-13b-chat", + "clarifai/meta.Llama-2.llama2-70b-chat", + "clarifai/openai.chat-completion.gpt-4-turbo", + "clarifai/microsoft.text-generation.phi-2", + "clarifai/meta.Llama-2.llama2-7b-chat-vllm", + "clarifai/upstage.solar.solar-10_7b-instruct", + "clarifai/openchat.openchat.openchat-3_5-1210", + "clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B", + "clarifai/gcp.generate.text-bison", + "clarifai/meta.Llama-2.llamaGuard-7b", + "clarifai/fblgit.una-cybertron.una-cybertron-7b-v2", + "clarifai/openai.chat-completion.GPT-4", + "clarifai/openai.chat-completion.GPT-3_5-turbo", + "clarifai/ai21.complete.Jurassic2-Grande", + "clarifai/ai21.complete.Jurassic2-Grande-Instruct", + "clarifai/ai21.complete.Jurassic2-Jumbo-Instruct", + "clarifai/ai21.complete.Jurassic2-Jumbo", + "clarifai/ai21.complete.Jurassic2-Large", + "clarifai/cohere.generate.cohere-generate-command", + "clarifai/wizardlm.generate.wizardCoder-Python-34B", + "clarifai/wizardlm.generate.wizardLM-70B", + "clarifai/tiiuae.falcon.falcon-40b-instruct", + "clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat", + "clarifai/gcp.generate.code-gecko", + "clarifai/gcp.generate.code-bison", + "clarifai/mistralai.completion.mistral-7B-OpenOrca", + "clarifai/mistralai.completion.openHermes-2-mistral-7B", + "clarifai/wizardlm.generate.wizardLM-13B", + "clarifai/huggingface-research.zephyr.zephyr-7B-alpha", + "clarifai/wizardlm.generate.wizardCoder-15B", + "clarifai/microsoft.text-generation.phi-1_5", + "clarifai/databricks.Dolly-v2.dolly-v2-12b", + "clarifai/bigcode.code.StarCoder", + "clarifai/salesforce.xgen.xgen-7b-8k-instruct", + "clarifai/mosaicml.mpt.mpt-7b-instruct", + "clarifai/anthropic.completion.claude-3-opus", + "clarifai/anthropic.completion.claude-3-sonnet", + "clarifai/gcp.generate.gemini-1_5-pro", + "clarifai/gcp.generate.imagen-2", + "clarifai/salesforce.blip.general-english-image-caption-blip-2", +] + + +huggingface_models: List = [ + "meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-2-7b-chat-hf", + "meta-llama/Llama-2-13b-hf", + "meta-llama/Llama-2-13b-chat-hf", + "meta-llama/Llama-2-70b-hf", + "meta-llama/Llama-2-70b-chat-hf", + "meta-llama/Llama-2-7b", + "meta-llama/Llama-2-7b-chat", + "meta-llama/Llama-2-13b", + "meta-llama/Llama-2-13b-chat", + "meta-llama/Llama-2-70b", + "meta-llama/Llama-2-70b-chat", +] # these have been tested on extensively. But by default all text2text-generation and text-generation models are supported by liteLLM. - https://docs.litellm.ai/docs/providers +empower_models = [ + "empower/empower-functions", + "empower/empower-functions-small", +] + +together_ai_models: List = [ + # llama llms - chat + "togethercomputer/llama-2-70b-chat", + # llama llms - language / instruct + "togethercomputer/llama-2-70b", + "togethercomputer/LLaMA-2-7B-32K", + "togethercomputer/Llama-2-7B-32K-Instruct", + "togethercomputer/llama-2-7b", + # falcon llms + "togethercomputer/falcon-40b-instruct", + "togethercomputer/falcon-7b-instruct", + # alpaca + "togethercomputer/alpaca-7b", + # chat llms + "HuggingFaceH4/starchat-alpha", + # code llms + "togethercomputer/CodeLlama-34b", + "togethercomputer/CodeLlama-34b-Instruct", + "togethercomputer/CodeLlama-34b-Python", + "defog/sqlcoder", + "NumbersStation/nsql-llama-2-7B", + "WizardLM/WizardCoder-15B-V1.0", + "WizardLM/WizardCoder-Python-34B-V1.0", + # language llms + "NousResearch/Nous-Hermes-Llama2-13b", + "Austism/chronos-hermes-13b", + "upstage/SOLAR-0-70b-16bit", + "WizardLM/WizardLM-70B-V1.0", +] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...) + + +baseten_models: List = [ + "qvv0xeq", + "q841o8w", + "31dxrj3", +] # FALCON 7B # WizardLM # Mosaic ML + + +# used for Cost Tracking & Token counting +# https://azure.microsoft.com/en-in/pricing/details/cognitive-services/openai-service/ +# Azure returns gpt-35-turbo in their responses, we need to map this to azure/gpt-3.5-turbo for token counting +azure_llms = { + "gpt-35-turbo": "azure/gpt-35-turbo", + "gpt-35-turbo-16k": "azure/gpt-35-turbo-16k", + "gpt-35-turbo-instruct": "azure/gpt-35-turbo-instruct", +} + +azure_embedding_models = { + "ada": "azure/ada", +} + +petals_models = [ + "petals-team/StableBeluga2", +] + +ollama_models = ["llama2"] + +maritalk_models = ["maritalk"] + +model_list = ( + open_ai_chat_completion_models + + open_ai_text_completion_models + + cohere_models + + cohere_chat_models + + anthropic_models + + replicate_models + + openrouter_models + + huggingface_models + + vertex_chat_models + + vertex_text_models + + ai21_models + + ai21_chat_models + + together_ai_models + + baseten_models + + aleph_alpha_models + + nlp_cloud_models + + ollama_models + + bedrock_models + + deepinfra_models + + perplexity_models + + maritalk_models + + vertex_language_models + + watsonx_models + + gemini_models +) + + +class LlmProviders(str, Enum): + OPENAI = "openai" + CUSTOM_OPENAI = "custom_openai" + TEXT_COMPLETION_OPENAI = "text-completion-openai" + COHERE = "cohere" + COHERE_CHAT = "cohere_chat" + CLARIFAI = "clarifai" + ANTHROPIC = "anthropic" + REPLICATE = "replicate" + HUGGINGFACE = "huggingface" + TOGETHER_AI = "together_ai" + OPENROUTER = "openrouter" + VERTEX_AI = "vertex_ai" + VERTEX_AI_BETA = "vertex_ai_beta" + PALM = "palm" + GEMINI = "gemini" + AI21 = "ai21" + BASETEN = "baseten" + AZURE = "azure" + AZURE_TEXT = "azure_text" + AZURE_AI = "azure_ai" + SAGEMAKER = "sagemaker" + SAGEMAKER_CHAT = "sagemaker_chat" + BEDROCK = "bedrock" + VLLM = "vllm" + NLP_CLOUD = "nlp_cloud" + PETALS = "petals" + OOBABOOGA = "oobabooga" + OLLAMA = "ollama" + OLLAMA_CHAT = "ollama_chat" + DEEPINFRA = "deepinfra" + PERPLEXITY = "perplexity" + ANYSCALE = "anyscale" + MISTRAL = "mistral" + GROQ = "groq" + NVIDIA_NIM = "nvidia_nim" + CEREBRAS = "cerebras" + AI21_CHAT = "ai21_chat" + VOLCENGINE = "volcengine" + CODESTRAL = "codestral" + TEXT_COMPLETION_CODESTRAL = "text-completion-codestral" + DEEPSEEK = "deepseek" + MARITALK = "maritalk" + VOYAGE = "voyage" + CLOUDFLARE = "cloudflare" + XINFERENCE = "xinference" + FIREWORKS_AI = "fireworks_ai" + FRIENDLIAI = "friendliai" + WATSONX = "watsonx" + TRITON = "triton" + PREDIBASE = "predibase" + DATABRICKS = "databricks" + EMPOWER = "empower" + GITHUB = "github" + CUSTOM = "custom" + NOTDIAMOND = "notdiamond" + + +provider_list: List[Union[LlmProviders, str]] = list(LlmProviders) + + +models_by_provider: dict = { + "openai": open_ai_chat_completion_models + open_ai_text_completion_models, + "cohere": cohere_models + cohere_chat_models, + "cohere_chat": cohere_chat_models, + "anthropic": anthropic_models, + "replicate": replicate_models, + "huggingface": huggingface_models, + "together_ai": together_ai_models, + "baseten": baseten_models, + "openrouter": openrouter_models, + "vertex_ai": vertex_chat_models + + vertex_text_models + + vertex_anthropic_models + + vertex_vision_models + + vertex_language_models, + "ai21": ai21_models, + "bedrock": bedrock_models, + "petals": petals_models, + "ollama": ollama_models, + "deepinfra": deepinfra_models, + "perplexity": perplexity_models, + "maritalk": maritalk_models, + "watsonx": watsonx_models, + "gemini": gemini_models, + "fireworks_ai": fireworks_ai_models, +} + +# mapping for those models which have larger equivalents +longer_context_model_fallback_dict: dict = { + # openai chat completion models + "gpt-3.5-turbo": "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301": "gpt-3.5-turbo-16k-0301", + "gpt-3.5-turbo-0613": "gpt-3.5-turbo-16k-0613", + "gpt-4": "gpt-4-32k", + "gpt-4-0314": "gpt-4-32k-0314", + "gpt-4-0613": "gpt-4-32k-0613", + # anthropic + "claude-instant-1": "claude-2", + "claude-instant-1.2": "claude-2", + # vertexai + "chat-bison": "chat-bison-32k", + "chat-bison@001": "chat-bison-32k", + "codechat-bison": "codechat-bison-32k", + "codechat-bison@001": "codechat-bison-32k", + # openrouter + "openrouter/openai/gpt-3.5-turbo": "openrouter/openai/gpt-3.5-turbo-16k", + "openrouter/anthropic/claude-instant-v1": "openrouter/anthropic/claude-2", +} + +####### EMBEDDING MODELS ################### +open_ai_embedding_models: List = ["text-embedding-ada-002"] +cohere_embedding_models: List = [ + "embed-english-v3.0", + "embed-english-light-v3.0", + "embed-multilingual-v3.0", + "embed-english-v2.0", + "embed-english-light-v2.0", + "embed-multilingual-v2.0", +] +bedrock_embedding_models: List = [ + "amazon.titan-embed-text-v1", + "cohere.embed-english-v3", + "cohere.embed-multilingual-v3", +] + +all_embedding_models = ( + open_ai_embedding_models + + cohere_embedding_models + + bedrock_embedding_models + + vertex_embedding_models +) + +####### IMAGE GENERATION MODELS ################### +openai_image_generation_models = ["dall-e-2", "dall-e-3"] + +from litellm.cost_calculator import completion_cost +from litellm.litellm_core_utils.core_helpers import ( + remove_index_from_tool_calls, +) +from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider +from litellm.litellm_core_utils.litellm_logging import Logging +from litellm.litellm_core_utils.token_counter import get_modified_max_tokens +from litellm.timeout import timeout +from litellm.utils import ( + EmbeddingResponse, + ImageResponse, + ModelResponse, + ModelResponseListIterator, + TextCompletionResponse, + TranscriptionResponse, + _calculate_retry_after, + _should_retry, + acreate, + check_valid_key, + client, + create_pretrained_tokenizer, + create_tokenizer, + decode, + encode, + exception_type, + get_api_base, + get_first_chars_messages, + get_litellm_params, + get_max_tokens, + get_model_info, + get_model_list, + get_optional_params, + get_provider_fields, + get_response_string, + get_supported_openai_params, + modify_integration, + register_model, + register_prompt_template, + supports_function_calling, + supports_parallel_function_calling, + supports_response_schema, + supports_system_messages, + supports_vision, + token_counter, + validate_environment, +) + +ALL_LITELLM_RESPONSE_TYPES = [ + ModelResponse, + EmbeddingResponse, + ImageResponse, + TranscriptionResponse, + TextCompletionResponse, +] + +from litellm.assistants.main import * +from litellm.batches.main import * +from litellm.budget_manager import BudgetManager +from litellm.cost_calculator import cost_per_token, response_cost_calculator +from litellm.exceptions import ( + LITELLM_EXCEPTION_TYPES, + APIConnectionError, + APIError, + APIResponseValidationError, + AuthenticationError, + BadRequestError, + BudgetExceededError, + ContentPolicyViolationError, + ContextWindowExceededError, + InternalServerError, + InvalidRequestError, + JSONSchemaValidationError, + MockException, + NotFoundError, + OpenAIError, + RateLimitError, + ServiceUnavailableError, + Timeout, + UnprocessableEntityError, + UnsupportedParamsError, +) +from litellm.files.main import * +from litellm.fine_tuning.main import * +from litellm.integrations import * +from litellm.llms.AI21.chat import AI21ChatConfig +from litellm.llms.AI21.completion import AI21Config +from litellm.llms.aleph_alpha import AlephAlphaConfig +from litellm.llms.anthropic.chat import AnthropicConfig +from litellm.llms.anthropic.completion import AnthropicTextConfig +from litellm.llms.AzureOpenAI.azure import ( + AzureOpenAIAssistantsAPIConfig, + AzureOpenAIConfig, + AzureOpenAIError, +) +from litellm.llms.bedrock.chat import ( + BEDROCK_CONVERSE_MODELS, + AmazonCohereChatConfig, + AmazonConverseConfig, + bedrock_tool_name_mappings, +) +from litellm.llms.bedrock.common_utils import ( + AmazonAI21Config, + AmazonAnthropicClaude3Config, + AmazonAnthropicConfig, + AmazonBedrockGlobalConfig, + AmazonCohereConfig, + AmazonLlamaConfig, + AmazonMistralConfig, + AmazonStabilityConfig, + AmazonTitanConfig, +) +from litellm.llms.bedrock.embed.amazon_titan_g1_transformation import ( + AmazonTitanG1Config, +) +from litellm.llms.bedrock.embed.amazon_titan_multimodal_transformation import ( + AmazonTitanMultimodalEmbeddingG1Config, +) +from litellm.llms.bedrock.embed.amazon_titan_v2_transformation import ( + AmazonTitanV2Config, +) +from litellm.llms.bedrock.embed.cohere_transformation import ( + BedrockCohereEmbeddingConfig, +) +from litellm.llms.cerebras.chat import CerebrasConfig +from litellm.llms.clarifai import ClarifaiConfig +from litellm.llms.cloudflare import CloudflareConfig +from litellm.llms.cohere.completion import CohereConfig +from litellm.llms.custom_llm import CustomLLM +from litellm.llms.databricks.chat import ( + DatabricksConfig, + DatabricksEmbeddingConfig, +) +from litellm.llms.fireworks_ai import FireworksAIConfig +from litellm.llms.gemini import GeminiConfig +from litellm.llms.huggingface_restapi import HuggingfaceConfig +from litellm.llms.maritalk import MaritTalkConfig +from litellm.llms.nlp_cloud import NLPCloudConfig +from litellm.llms.nvidia_nim import NvidiaNimConfig +from litellm.llms.ollama import OllamaConfig +from litellm.llms.ollama_chat import OllamaChatConfig +from litellm.llms.OpenAI.o1_reasoning import OpenAIO1Config +from litellm.llms.OpenAI.openai import ( + AzureAIStudioConfig, + DeepInfraConfig, + GroqConfig, + MistralConfig, + MistralEmbeddingConfig, + OpenAIConfig, + OpenAITextCompletionConfig, +) +from litellm.llms.palm import PalmConfig +from litellm.llms.petals import PetalsConfig +from litellm.llms.predibase import PredibaseConfig +from litellm.llms.replicate import ReplicateConfig +from litellm.llms.sagemaker.sagemaker import SagemakerConfig +from litellm.llms.text_completion_codestral import MistralTextCompletionConfig +from litellm.llms.together_ai import TogetherAIConfig +from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( + GoogleAIStudioGeminiConfig, + VertexAIConfig, + VertexGeminiConfig, +) +from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import ( + VertexAIAnthropicConfig, +) +from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import ( + VertexAIAi21Config, +) +from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import ( + VertexAILlama3Config, +) +from litellm.llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import ( + VertexAITextEmbeddingConfig, +) +from litellm.llms.volcengine import VolcEngineConfig +from litellm.llms.watsonx import IBMWatsonXAIConfig +from litellm.proxy.proxy_cli import run_server +from litellm.rerank_api.main import * +from litellm.router import Router +from litellm.scheduler import * + +### ADAPTERS ### +from litellm.types.adapter import AdapterItem +from litellm.types.utils import ImageObject + +from .main import * + +adapters: List[AdapterItem] = [] + +### CUSTOM LLMs ### +from litellm.types.llms.custom_llm import CustomLLMItem +from litellm.types.utils import GenericStreamingChunk + +custom_provider_map: List[CustomLLMItem] = [] +_custom_providers: List[ + str +] = [] # internal helper util, used to track names of custom providers diff --git a/notdiamond/toolkit/litellm/litellm.py b/notdiamond/toolkit/litellm/main.py similarity index 99% rename from notdiamond/toolkit/litellm/litellm.py rename to notdiamond/toolkit/litellm/main.py index 591bf569..aeb57643 100644 --- a/notdiamond/toolkit/litellm/litellm.py +++ b/notdiamond/toolkit/litellm/main.py @@ -171,6 +171,7 @@ read_config_args, ) +from . import notdiamond_key, provider_list from .litellm_notdiamond import completion as notdiamond_completion openai_chat_completions = OpenAIChatCompletion() @@ -198,9 +199,6 @@ watsonxai = IBMWatsonXAI() sagemaker_llm = SagemakerLLM() -litellm.provider_list.append("notdiamond") -litellm.notdiamond_key = None - class LiteLLM: def __init__( @@ -306,11 +304,7 @@ def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]): ) # notdiamond elif llm_provider == "notdiamond": - api_key = ( - api_key - or litellm.notdiamond_key - or get_secret("NOTDIAMOND_API_KEY") - ) + api_key = api_key or notdiamond_key or get_secret("NOTDIAMOND_API_KEY") # nlp_cloud elif llm_provider == "nlp_cloud": api_key = ( @@ -374,7 +368,7 @@ def get_llm_provider( dynamic_api_key = get_secret(api_key) # check if llm provider part of model name if ( - model.split("/", 1)[0] in litellm.provider_list + model.split("/", 1)[0] in provider_list and model.split("/", 1)[0] not in litellm.model_list and len(model.split("/")) > 1 # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351 @@ -495,7 +489,7 @@ def get_llm_provider( ) ) return model, custom_llm_provider, dynamic_api_key, api_base - elif model.split("/", 1)[0] in litellm.provider_list: + elif model.split("/", 1)[0] in provider_list: custom_llm_provider = model.split("/", 1)[0] model = model.split("/", 1)[1] if api_base is not None and not isinstance(api_base, str): @@ -1565,7 +1559,7 @@ def completion( elif custom_llm_provider == "notdiamond": notdiamond_key = ( api_key - or litellm.notdiamond_key + or notdiamond_key or get_secret("NOTDIAMOND_API_KEY") or litellm.api_key ) @@ -3538,7 +3532,7 @@ def batch_completion( completions = [] model = model custom_llm_provider = None - if model.split("/", 1)[0] in litellm.provider_list: + if model.split("/", 1)[0] in provider_list: custom_llm_provider = model.split("/", 1)[0] model = model.split("/", 1)[1] if custom_llm_provider == "vllm": @@ -5694,7 +5688,7 @@ def speech( if response is None: raise Exception( "Unable to map the custom llm provider={} to a known provider={}.".format( - custom_llm_provider, litellm.provider_list + custom_llm_provider, provider_list ) ) return response diff --git a/poetry.lock b/poetry.lock index ae106ec6..a9bdc0e9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2195,13 +2195,13 @@ requests = ">=2,<3" [[package]] name = "litellm" -version = "1.44.26" +version = "1.44.28" description = "Library to easily interface with LLM API providers" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ - {file = "litellm-1.44.26-py3-none-any.whl", hash = "sha256:de63115a19e1432a44e38b9a2d8d8dfb77d5745370ef80d45327017b7341dea5"}, - {file = "litellm-1.44.26.tar.gz", hash = "sha256:10856ad8e9b5fed96f2f4ad62dd3266240c2d996ef20ee9e233a2140e9c1eb5b"}, + {file = "litellm-1.44.28-py3-none-any.whl", hash = "sha256:a4476c1f076b7996a97bd5d51e53760be482f1ec888c0c626c7877d5c6ff0849"}, + {file = "litellm-1.44.28.tar.gz", hash = "sha256:9a9055ce3f655201e4527786c219eaa98579c0134c031418bc38744fed3cd265"}, ] [package.dependencies] @@ -2210,7 +2210,7 @@ click = "*" importlib-metadata = ">=6.8.0" jinja2 = ">=3.1.2,<4.0.0" jsonschema = ">=4.22.0,<5.0.0" -openai = ">=1.40.0" +openai = ">=1.45.0" pydantic = ">=2.0.0,<3.0.0" python-dotenv = ">=0.2.0" requests = ">=2.31.0,<3.0.0" @@ -2586,13 +2586,13 @@ files = [ [[package]] name = "openai" -version = "1.43.0" +version = "1.45.0" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.43.0-py3-none-any.whl", hash = "sha256:1a748c2728edd3a738a72a0212ba866f4fdbe39c9ae03813508b267d45104abe"}, - {file = "openai-1.43.0.tar.gz", hash = "sha256:e607aff9fc3e28eade107e5edd8ca95a910a4b12589336d3cbb6bfe2ac306b3c"}, + {file = "openai-1.45.0-py3-none-any.whl", hash = "sha256:2f1f7b7cf90f038a9f1c24f0d26c0f1790c102ec5acd07ffd70a9b7feac1ff4e"}, + {file = "openai-1.45.0.tar.gz", hash = "sha256:731207d10637335413aa3c0955f8f8df30d7636a4a0f9c381f2209d32cf8de97"}, ] [package.dependencies] diff --git a/tests/test_toolkit/test_litellm.py b/tests/test_toolkit/test_litellm.py index 7c73c50e..ace4ea34 100644 --- a/tests/test_toolkit/test_litellm.py +++ b/tests/test_toolkit/test_litellm.py @@ -47,10 +47,10 @@ {"provider": "google", "model": "gemini-1.5-flash-latest"}, {"provider": "google", "model": "gemini-1.0-pro-latest"}, # {"provider": "replicate", "model": "mistral-7b-instruct-v0.2"}, removed due to replicate side error - {"provider": "replicate", "model": "mixtral-8x7b-instruct-v0.1"}, - {"provider": "replicate", "model": "meta-llama-3-70b-instruct"}, + # {"provider": "replicate", "model": "mixtral-8x7b-instruct-v0.1"}, removed due to replicate side error + # {"provider": "replicate", "model": "meta-llama-3-70b-instruct"}, removed due to replicate side error + # {"provider": "replicate", "model": "meta-llama-3.1-405b-instruct"}, removed due to replicate side error {"provider": "replicate", "model": "meta-llama-3-8b-instruct"}, - {"provider": "replicate", "model": "meta-llama-3.1-405b-instruct"}, {"provider": "togetherai", "model": "Mistral-7B-Instruct-v0.2"}, {"provider": "togetherai", "model": "Mixtral-8x7B-Instruct-v0.1"}, {"provider": "togetherai", "model": "Mixtral-8x22B-Instruct-v0.1"}, From 95df0402c47be9200caaba0a69c652b7773d01f8 Mon Sep 17 00:00:00 2001 From: Tze-Yang Tung Date: Thu, 26 Sep 2024 14:55:57 -0400 Subject: [PATCH 4/7] massive cleanup --- notdiamond/toolkit/litellm/__init__.py | 983 +----- .../toolkit/litellm/litellm_notdiamond.py | 6 +- notdiamond/toolkit/litellm/main.py | 3101 +---------------- tests/test_toolkit/test_litellm.py | 2 + 4 files changed, 38 insertions(+), 4054 deletions(-) diff --git a/notdiamond/toolkit/litellm/__init__.py b/notdiamond/toolkit/litellm/__init__.py index 1140ee27..b53f29ac 100644 --- a/notdiamond/toolkit/litellm/__init__.py +++ b/notdiamond/toolkit/litellm/__init__.py @@ -1,714 +1,11 @@ -# flake8: noqa -# -### Hide pydantic namespace conflict warnings globally ### -import warnings - -from .litellm_notdiamond import NotDiamondConfig - -warnings.filterwarnings( - "ignore", message=".*conflict with protected namespace.*" -) -import os - -### INIT VARIABLES ### -import threading from enum import Enum -from typing import ( - Any, - Callable, - Dict, - List, - Literal, - Optional, - Union, - get_args, -) +from typing import List, Optional, Union -import dotenv -import httpx -import requests -from litellm._logging import ( - _turn_on_debug, - _turn_on_json, - json_logs, - log_level, - set_verbose, - verbose_logger, -) -from litellm.caching import Cache -from litellm.llms.custom_httpx.http_handler import ( - AsyncHTTPHandler, - HTTPHandler, -) -from litellm.proxy._types import ( - KeyManagementSettings, - KeyManagementSystem, - LiteLLM_UpperboundKeyGenerateParams, -) -from litellm.types.guardrails import GuardrailItem +from litellm.__init__ import * # noqa -litellm_mode = os.getenv("LITELLM_MODE", "DEV") # "PRODUCTION", "DEV" -if litellm_mode == "DEV": - dotenv.load_dotenv() -############################################# -if set_verbose == True: - _turn_on_debug() -############################################# -### Callbacks /Logging / Success / Failure Handlers ### -input_callback: List[Union[str, Callable]] = [] -success_callback: List[Union[str, Callable]] = [] -failure_callback: List[Union[str, Callable]] = [] -service_callback: List[Union[str, Callable]] = [] -_custom_logger_compatible_callbacks_literal = Literal[ - "lago", - "openmeter", - "logfire", - "dynamic_rate_limiter", - "langsmith", - "prometheus", - "galileo", - "braintrust", - "arize", - "gcs_bucket", -] -_known_custom_logger_compatible_callbacks: List = list( - get_args(_custom_logger_compatible_callbacks_literal) -) -callbacks: List[ - Union[Callable, _custom_logger_compatible_callbacks_literal] -] = [] -langfuse_default_tags: Optional[List[str]] = None -langsmith_batch_size: Optional[int] = None -_async_input_callback: List[ - Callable -] = [] # internal variable - async custom callbacks are routed here. -_async_success_callback: List[ - Union[str, Callable] -] = [] # internal variable - async custom callbacks are routed here. -_async_failure_callback: List[ - Callable -] = [] # internal variable - async custom callbacks are routed here. -pre_call_rules: List[Callable] = [] -post_call_rules: List[Callable] = [] -turn_off_message_logging: Optional[bool] = False -log_raw_request_response: bool = False -redact_messages_in_exceptions: Optional[bool] = False -redact_user_api_key_info: Optional[bool] = False -store_audit_logs = False # Enterprise feature, allow users to see audit logs -## end of callbacks ############# +from .litellm_notdiamond import NotDiamondConfig # noqa -email: Optional[ - str -] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -token: Optional[ - str -] = None # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -telemetry = True -max_tokens = 256 # OpenAI Defaults -drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False)) -modify_params = False -retry = True -### AUTH ### -api_key: Optional[str] = None -openai_key: Optional[str] = None -databricks_key: Optional[str] = None -azure_key: Optional[str] = None -anthropic_key: Optional[str] = None -replicate_key: Optional[str] = None -cohere_key: Optional[str] = None -clarifai_key: Optional[str] = None -maritalk_key: Optional[str] = None -ai21_key: Optional[str] = None -ollama_key: Optional[str] = None -openrouter_key: Optional[str] = None -predibase_key: Optional[str] = None -huggingface_key: Optional[str] = None -vertex_project: Optional[str] = None -vertex_location: Optional[str] = None -predibase_tenant_id: Optional[str] = None -togetherai_api_key: Optional[str] = None -cloudflare_api_key: Optional[str] = None -baseten_key: Optional[str] = None notdiamond_key: Optional[str] = None -aleph_alpha_key: Optional[str] = None -nlp_cloud_key: Optional[str] = None -common_cloud_provider_auth_params: dict = { - "params": ["project", "region_name", "token"], - "providers": [ - "vertex_ai", - "bedrock", - "watsonx", - "azure", - "vertex_ai_beta", - ], -} -use_client: bool = False -ssl_verify: Union[str, bool] = True -ssl_certificate: Optional[str] = None -disable_streaming_logging: bool = False -in_memory_llm_clients_cache: dict = {} -safe_memory_mode: bool = False -enable_azure_ad_token_refresh: Optional[bool] = False -### DEFAULT AZURE API VERSION ### -AZURE_DEFAULT_API_VERSION = ( - "2024-08-01-preview" # this is updated to the latest -) -### COHERE EMBEDDINGS DEFAULT TYPE ### -COHERE_DEFAULT_EMBEDDING_INPUT_TYPE = "search_document" -### GUARDRAILS ### -llamaguard_model_name: Optional[str] = None -openai_moderations_model_name: Optional[str] = None -presidio_ad_hoc_recognizers: Optional[str] = None -google_moderation_confidence_threshold: Optional[float] = None -llamaguard_unsafe_content_categories: Optional[str] = None -blocked_user_list: Optional[Union[str, List]] = None -banned_keywords_list: Optional[Union[str, List]] = None -llm_guard_mode: Literal["all", "key-specific", "request-specific"] = "all" -guardrail_name_config_map: Dict[str, GuardrailItem] = {} -################## -### PREVIEW FEATURES ### -enable_preview_features: bool = False -return_response_headers: bool = False # get response headers from LLM Api providers - example x-remaining-requests, -enable_json_schema_validation: bool = False -################## -logging: bool = True -enable_loadbalancing_on_batch_endpoints: Optional[bool] = None -enable_caching_on_provider_specific_optional_params: bool = ( - False # feature-flag for caching on optional params - e.g. 'top_k' -) -caching: bool = False # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -always_read_redis: bool = ( - True # always use redis for rate limiting logic on litellm proxy -) -caching_with_models: bool = False # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648 -cache: Optional[ - Cache -] = None # cache object <- use this - https://docs.litellm.ai/docs/caching -default_in_memory_ttl: Optional[float] = None -default_redis_ttl: Optional[float] = None -model_alias_map: Dict[str, str] = {} -model_group_alias_map: Dict[str, str] = {} -max_budget: float = 0.0 # set the max budget across all providers -budget_duration: Optional[ - str -] = None # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). -default_soft_budget: float = ( - 50.0 # by default all litellm proxy keys have a soft budget of 50.0 -) -forward_traceparent_to_llm_provider: bool = False -_openai_finish_reasons = [ - "stop", - "length", - "function_call", - "content_filter", - "null", -] -_openai_completion_params = [ - "functions", - "function_call", - "temperature", - "temperature", - "top_p", - "n", - "stream", - "stop", - "max_tokens", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - "request_timeout", - "api_base", - "api_version", - "api_key", - "deployment_id", - "organization", - "base_url", - "default_headers", - "timeout", - "response_format", - "seed", - "tools", - "tool_choice", - "max_retries", -] -_litellm_completion_params = [ - "metadata", - "acompletion", - "caching", - "mock_response", - "api_key", - "api_version", - "api_base", - "force_timeout", - "logger_fn", - "verbose", - "custom_llm_provider", - "litellm_logging_obj", - "litellm_call_id", - "use_client", - "id", - "fallbacks", - "azure", - "headers", - "model_list", - "num_retries", - "context_window_fallback_dict", - "roles", - "final_prompt_value", - "bos_token", - "eos_token", - "request_timeout", - "complete_response", - "self", - "client", - "rpm", - "tpm", - "input_cost_per_token", - "output_cost_per_token", - "hf_model_name", - "model_info", - "proxy_server_request", - "preset_cache_key", -] -_current_cost = 0 # private variable, used if max budget is set -error_logs: Dict = {} -add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt -client_session: Optional[httpx.Client] = None -aclient_session: Optional[httpx.AsyncClient] = None -model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks' -model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" -suppress_debug_info = False -dynamodb_table_name: Optional[str] = None -s3_callback_params: Optional[Dict] = None -generic_logger_headers: Optional[Dict] = None -default_key_generate_params: Optional[Dict] = None -upperbound_key_generate_params: Optional[ - LiteLLM_UpperboundKeyGenerateParams -] = None -default_user_params: Optional[Dict] = None -default_team_settings: Optional[List] = None -max_user_budget: Optional[float] = None -default_max_internal_user_budget: Optional[float] = None -max_internal_user_budget: Optional[float] = None -internal_user_budget_duration: Optional[str] = None -max_end_user_budget: Optional[float] = None -#### REQUEST PRIORITIZATION #### -priority_reservation: Optional[Dict[str, float]] = None -#### RELIABILITY #### -REPEATED_STREAMING_CHUNK_LIMIT = 100 # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives. -request_timeout: float = 6000 -module_level_aclient = AsyncHTTPHandler(timeout=request_timeout) -module_level_client = HTTPHandler(timeout=request_timeout) -num_retries: Optional[int] = None # per model endpoint -default_fallbacks: Optional[List] = None -fallbacks: Optional[List] = None -context_window_fallbacks: Optional[List] = None -content_policy_fallbacks: Optional[List] = None -allowed_fails: int = 3 -num_retries_per_request: Optional[ - int -] = None # for the request overall (incl. fallbacks + model retries) -####### SECRET MANAGERS ##################### -secret_manager_client: Optional[ - Any -] = None # list of instantiated key management clients - e.g. azure kv, infisical, etc. -_google_kms_resource_name: Optional[str] = None -_key_management_system: Optional[KeyManagementSystem] = None -_key_management_settings: Optional[KeyManagementSettings] = None -#### PII MASKING #### -output_parse_pii: bool = False -############################################# - - -def get_model_cost_map(url: str): - if ( - os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == True - or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True" - ): - import importlib.resources - import json - - with importlib.resources.open_text( - "litellm", "model_prices_and_context_window_backup.json" - ) as f: - content = json.load(f) - return content - - try: - with requests.get( - url, timeout=5 - ) as response: # set a 5 second timeout for the get request - response.raise_for_status() # Raise an exception if the request is unsuccessful - content = response.json() - return content - except Exception as e: - import importlib.resources - import json - - with importlib.resources.open_text( - "litellm", "model_prices_and_context_window_backup.json" - ) as f: - content = json.load(f) - return content - - -model_cost = get_model_cost_map(url=model_cost_map_url) -custom_prompt_dict: Dict[str, dict] = {} - - -####### THREAD-SPECIFIC DATA ################### -class MyLocal(threading.local): - def __init__(self): - self.user = "Hello World" - - -_thread_context = MyLocal() - - -def identify(event_details): - # Store user in thread local data - if "user" in event_details: - _thread_context.user = event_details["user"] - - -####### ADDITIONAL PARAMS ################### configurable params if you use proxy models like Helicone, map spend to org id, etc. -api_base = None -headers = None -api_version = None -organization = None -project = None -config_path = None -vertex_ai_safety_settings: Optional[dict] = None -####### COMPLETION MODELS ################### -open_ai_chat_completion_models: List = [] -open_ai_text_completion_models: List = [] -cohere_models: List = [] -cohere_chat_models: List = [] -mistral_chat_models: List = [] -anthropic_models: List = [] -empower_models: List = [] -openrouter_models: List = [] -vertex_language_models: List = [] -vertex_vision_models: List = [] -vertex_chat_models: List = [] -vertex_code_chat_models: List = [] -vertex_ai_image_models: List = [] -vertex_text_models: List = [] -vertex_code_text_models: List = [] -vertex_embedding_models: List = [] -vertex_anthropic_models: List = [] -vertex_llama3_models: List = [] -vertex_ai_ai21_models: List = [] -vertex_mistral_models: List = [] -ai21_models: List = [] -ai21_chat_models: List = [] -nlp_cloud_models: List = [] -aleph_alpha_models: List = [] -bedrock_models: List = [] -fireworks_ai_models: List = [] -deepinfra_models: List = [] -perplexity_models: List = [] -watsonx_models: List = [] -gemini_models: List = [] -for key, value in model_cost.items(): - if value.get("litellm_provider") == "openai": - open_ai_chat_completion_models.append(key) - elif value.get("litellm_provider") == "text-completion-openai": - open_ai_text_completion_models.append(key) - elif value.get("litellm_provider") == "cohere": - cohere_models.append(key) - elif value.get("litellm_provider") == "cohere_chat": - cohere_chat_models.append(key) - elif value.get("litellm_provider") == "mistral": - mistral_chat_models.append(key) - elif value.get("litellm_provider") == "anthropic": - anthropic_models.append(key) - elif value.get("litellm_provider") == "empower": - empower_models.append(key) - elif value.get("litellm_provider") == "openrouter": - openrouter_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-text-models": - vertex_text_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-code-text-models": - vertex_code_text_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-language-models": - vertex_language_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-vision-models": - vertex_vision_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-chat-models": - vertex_chat_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-code-chat-models": - vertex_code_chat_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-embedding-models": - vertex_embedding_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-anthropic_models": - key = key.replace("vertex_ai/", "") - vertex_anthropic_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-llama_models": - key = key.replace("vertex_ai/", "") - vertex_llama3_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-mistral_models": - key = key.replace("vertex_ai/", "") - vertex_mistral_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-ai21_models": - key = key.replace("vertex_ai/", "") - vertex_ai_ai21_models.append(key) - elif value.get("litellm_provider") == "vertex_ai-image-models": - key = key.replace("vertex_ai/", "") - vertex_ai_image_models.append(key) - elif value.get("litellm_provider") == "ai21": - if value.get("mode") == "chat": - ai21_chat_models.append(key) - else: - ai21_models.append(key) - elif value.get("litellm_provider") == "nlp_cloud": - nlp_cloud_models.append(key) - elif value.get("litellm_provider") == "aleph_alpha": - aleph_alpha_models.append(key) - elif value.get("litellm_provider") == "bedrock": - bedrock_models.append(key) - elif value.get("litellm_provider") == "deepinfra": - deepinfra_models.append(key) - elif value.get("litellm_provider") == "perplexity": - perplexity_models.append(key) - elif value.get("litellm_provider") == "watsonx": - watsonx_models.append(key) - elif value.get("litellm_provider") == "gemini": - gemini_models.append(key) - elif value.get("litellm_provider") == "fireworks_ai": - fireworks_ai_models.append(key) -# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary -openai_compatible_endpoints: List = [ - "api.perplexity.ai", - "api.endpoints.anyscale.com/v1", - "api.deepinfra.com/v1/openai", - "api.mistral.ai/v1", - "codestral.mistral.ai/v1/chat/completions", - "codestral.mistral.ai/v1/fim/completions", - "api.groq.com/openai/v1", - "https://integrate.api.nvidia.com/v1", - "api.deepseek.com/v1", - "api.together.xyz/v1", - "app.empower.dev/api/v1", - "inference.friendli.ai/v1", -] - -# this is maintained for Exception Mapping -openai_compatible_providers: List = [ - "anyscale", - "mistral", - "groq", - "nvidia_nim", - "cerebras", - "ai21_chat", - "volcengine", - "codestral", - "deepseek", - "deepinfra", - "perplexity", - "xinference", - "together_ai", - "fireworks_ai", - "empower", - "friendliai", - "azure_ai", - "github", -] -openai_text_completion_compatible_providers: List = ( - [ # providers that support `/v1/completions` - "together_ai", - "fireworks_ai", - ] -) - -# well supported replicate llms -replicate_models: List = [ - # llama replicate supported LLMs - "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf", - "a16z-infra/llama-2-13b-chat:2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52", - "meta/codellama-13b:1c914d844307b0588599b8393480a3ba917b660c7e9dfae681542b5325f228db", - # Vicuna - "replicate/vicuna-13b:6282abe6a492de4145d7bb601023762212f9ddbbe78278bd6771c8b3b2f2a13b", - "joehoover/instructblip-vicuna13b:c4c54e3c8c97cd50c2d2fec9be3b6065563ccf7d43787fb99f84151b867178fe", - # Flan T-5 - "daanelson/flan-t5-large:ce962b3f6792a57074a601d3979db5839697add2e4e02696b3ced4c022d4767f", - # Others - "replicate/dolly-v2-12b:ef0e1aefc61f8e096ebe4db6b2bacc297daf2ef6899f0f7e001ec445893500e5", - "replit/replit-code-v1-3b:b84f4c074b807211cd75e3e8b1589b6399052125b4c27106e43d47189e8415ad", -] - -clarifai_models: List = [ - "clarifai/meta.Llama-3.Llama-3-8B-Instruct", - "clarifai/gcp.generate.gemma-1_1-7b-it", - "clarifai/mistralai.completion.mixtral-8x22B", - "clarifai/cohere.generate.command-r-plus", - "clarifai/databricks.drbx.dbrx-instruct", - "clarifai/mistralai.completion.mistral-large", - "clarifai/mistralai.completion.mistral-medium", - "clarifai/mistralai.completion.mistral-small", - "clarifai/mistralai.completion.mixtral-8x7B-Instruct-v0_1", - "clarifai/gcp.generate.gemma-2b-it", - "clarifai/gcp.generate.gemma-7b-it", - "clarifai/deci.decilm.deciLM-7B-instruct", - "clarifai/mistralai.completion.mistral-7B-Instruct", - "clarifai/gcp.generate.gemini-pro", - "clarifai/anthropic.completion.claude-v1", - "clarifai/anthropic.completion.claude-instant-1_2", - "clarifai/anthropic.completion.claude-instant", - "clarifai/anthropic.completion.claude-v2", - "clarifai/anthropic.completion.claude-2_1", - "clarifai/meta.Llama-2.codeLlama-70b-Python", - "clarifai/meta.Llama-2.codeLlama-70b-Instruct", - "clarifai/openai.completion.gpt-3_5-turbo-instruct", - "clarifai/meta.Llama-2.llama2-7b-chat", - "clarifai/meta.Llama-2.llama2-13b-chat", - "clarifai/meta.Llama-2.llama2-70b-chat", - "clarifai/openai.chat-completion.gpt-4-turbo", - "clarifai/microsoft.text-generation.phi-2", - "clarifai/meta.Llama-2.llama2-7b-chat-vllm", - "clarifai/upstage.solar.solar-10_7b-instruct", - "clarifai/openchat.openchat.openchat-3_5-1210", - "clarifai/togethercomputer.stripedHyena.stripedHyena-Nous-7B", - "clarifai/gcp.generate.text-bison", - "clarifai/meta.Llama-2.llamaGuard-7b", - "clarifai/fblgit.una-cybertron.una-cybertron-7b-v2", - "clarifai/openai.chat-completion.GPT-4", - "clarifai/openai.chat-completion.GPT-3_5-turbo", - "clarifai/ai21.complete.Jurassic2-Grande", - "clarifai/ai21.complete.Jurassic2-Grande-Instruct", - "clarifai/ai21.complete.Jurassic2-Jumbo-Instruct", - "clarifai/ai21.complete.Jurassic2-Jumbo", - "clarifai/ai21.complete.Jurassic2-Large", - "clarifai/cohere.generate.cohere-generate-command", - "clarifai/wizardlm.generate.wizardCoder-Python-34B", - "clarifai/wizardlm.generate.wizardLM-70B", - "clarifai/tiiuae.falcon.falcon-40b-instruct", - "clarifai/togethercomputer.RedPajama.RedPajama-INCITE-7B-Chat", - "clarifai/gcp.generate.code-gecko", - "clarifai/gcp.generate.code-bison", - "clarifai/mistralai.completion.mistral-7B-OpenOrca", - "clarifai/mistralai.completion.openHermes-2-mistral-7B", - "clarifai/wizardlm.generate.wizardLM-13B", - "clarifai/huggingface-research.zephyr.zephyr-7B-alpha", - "clarifai/wizardlm.generate.wizardCoder-15B", - "clarifai/microsoft.text-generation.phi-1_5", - "clarifai/databricks.Dolly-v2.dolly-v2-12b", - "clarifai/bigcode.code.StarCoder", - "clarifai/salesforce.xgen.xgen-7b-8k-instruct", - "clarifai/mosaicml.mpt.mpt-7b-instruct", - "clarifai/anthropic.completion.claude-3-opus", - "clarifai/anthropic.completion.claude-3-sonnet", - "clarifai/gcp.generate.gemini-1_5-pro", - "clarifai/gcp.generate.imagen-2", - "clarifai/salesforce.blip.general-english-image-caption-blip-2", -] - - -huggingface_models: List = [ - "meta-llama/Llama-2-7b-hf", - "meta-llama/Llama-2-7b-chat-hf", - "meta-llama/Llama-2-13b-hf", - "meta-llama/Llama-2-13b-chat-hf", - "meta-llama/Llama-2-70b-hf", - "meta-llama/Llama-2-70b-chat-hf", - "meta-llama/Llama-2-7b", - "meta-llama/Llama-2-7b-chat", - "meta-llama/Llama-2-13b", - "meta-llama/Llama-2-13b-chat", - "meta-llama/Llama-2-70b", - "meta-llama/Llama-2-70b-chat", -] # these have been tested on extensively. But by default all text2text-generation and text-generation models are supported by liteLLM. - https://docs.litellm.ai/docs/providers -empower_models = [ - "empower/empower-functions", - "empower/empower-functions-small", -] - -together_ai_models: List = [ - # llama llms - chat - "togethercomputer/llama-2-70b-chat", - # llama llms - language / instruct - "togethercomputer/llama-2-70b", - "togethercomputer/LLaMA-2-7B-32K", - "togethercomputer/Llama-2-7B-32K-Instruct", - "togethercomputer/llama-2-7b", - # falcon llms - "togethercomputer/falcon-40b-instruct", - "togethercomputer/falcon-7b-instruct", - # alpaca - "togethercomputer/alpaca-7b", - # chat llms - "HuggingFaceH4/starchat-alpha", - # code llms - "togethercomputer/CodeLlama-34b", - "togethercomputer/CodeLlama-34b-Instruct", - "togethercomputer/CodeLlama-34b-Python", - "defog/sqlcoder", - "NumbersStation/nsql-llama-2-7B", - "WizardLM/WizardCoder-15B-V1.0", - "WizardLM/WizardCoder-Python-34B-V1.0", - # language llms - "NousResearch/Nous-Hermes-Llama2-13b", - "Austism/chronos-hermes-13b", - "upstage/SOLAR-0-70b-16bit", - "WizardLM/WizardLM-70B-V1.0", -] # supports all together ai models, just pass in the model id e.g. completion(model="together_computer/replit_code_3b",...) - - -baseten_models: List = [ - "qvv0xeq", - "q841o8w", - "31dxrj3", -] # FALCON 7B # WizardLM # Mosaic ML - - -# used for Cost Tracking & Token counting -# https://azure.microsoft.com/en-in/pricing/details/cognitive-services/openai-service/ -# Azure returns gpt-35-turbo in their responses, we need to map this to azure/gpt-3.5-turbo for token counting -azure_llms = { - "gpt-35-turbo": "azure/gpt-35-turbo", - "gpt-35-turbo-16k": "azure/gpt-35-turbo-16k", - "gpt-35-turbo-instruct": "azure/gpt-35-turbo-instruct", -} - -azure_embedding_models = { - "ada": "azure/ada", -} - -petals_models = [ - "petals-team/StableBeluga2", -] - -ollama_models = ["llama2"] - -maritalk_models = ["maritalk"] - -model_list = ( - open_ai_chat_completion_models - + open_ai_text_completion_models - + cohere_models - + cohere_chat_models - + anthropic_models - + replicate_models - + openrouter_models - + huggingface_models - + vertex_chat_models - + vertex_text_models - + ai21_models - + ai21_chat_models - + together_ai_models - + baseten_models - + aleph_alpha_models - + nlp_cloud_models - + ollama_models - + bedrock_models - + deepinfra_models - + perplexity_models - + maritalk_models - + vertex_language_models - + watsonx_models - + gemini_models -) class LlmProviders(str, Enum): @@ -771,276 +68,4 @@ class LlmProviders(str, Enum): provider_list: List[Union[LlmProviders, str]] = list(LlmProviders) - -models_by_provider: dict = { - "openai": open_ai_chat_completion_models + open_ai_text_completion_models, - "cohere": cohere_models + cohere_chat_models, - "cohere_chat": cohere_chat_models, - "anthropic": anthropic_models, - "replicate": replicate_models, - "huggingface": huggingface_models, - "together_ai": together_ai_models, - "baseten": baseten_models, - "openrouter": openrouter_models, - "vertex_ai": vertex_chat_models - + vertex_text_models - + vertex_anthropic_models - + vertex_vision_models - + vertex_language_models, - "ai21": ai21_models, - "bedrock": bedrock_models, - "petals": petals_models, - "ollama": ollama_models, - "deepinfra": deepinfra_models, - "perplexity": perplexity_models, - "maritalk": maritalk_models, - "watsonx": watsonx_models, - "gemini": gemini_models, - "fireworks_ai": fireworks_ai_models, -} - -# mapping for those models which have larger equivalents -longer_context_model_fallback_dict: dict = { - # openai chat completion models - "gpt-3.5-turbo": "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-0301": "gpt-3.5-turbo-16k-0301", - "gpt-3.5-turbo-0613": "gpt-3.5-turbo-16k-0613", - "gpt-4": "gpt-4-32k", - "gpt-4-0314": "gpt-4-32k-0314", - "gpt-4-0613": "gpt-4-32k-0613", - # anthropic - "claude-instant-1": "claude-2", - "claude-instant-1.2": "claude-2", - # vertexai - "chat-bison": "chat-bison-32k", - "chat-bison@001": "chat-bison-32k", - "codechat-bison": "codechat-bison-32k", - "codechat-bison@001": "codechat-bison-32k", - # openrouter - "openrouter/openai/gpt-3.5-turbo": "openrouter/openai/gpt-3.5-turbo-16k", - "openrouter/anthropic/claude-instant-v1": "openrouter/anthropic/claude-2", -} - -####### EMBEDDING MODELS ################### -open_ai_embedding_models: List = ["text-embedding-ada-002"] -cohere_embedding_models: List = [ - "embed-english-v3.0", - "embed-english-light-v3.0", - "embed-multilingual-v3.0", - "embed-english-v2.0", - "embed-english-light-v2.0", - "embed-multilingual-v2.0", -] -bedrock_embedding_models: List = [ - "amazon.titan-embed-text-v1", - "cohere.embed-english-v3", - "cohere.embed-multilingual-v3", -] - -all_embedding_models = ( - open_ai_embedding_models - + cohere_embedding_models - + bedrock_embedding_models - + vertex_embedding_models -) - -####### IMAGE GENERATION MODELS ################### -openai_image_generation_models = ["dall-e-2", "dall-e-3"] - -from litellm.cost_calculator import completion_cost -from litellm.litellm_core_utils.core_helpers import ( - remove_index_from_tool_calls, -) -from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider -from litellm.litellm_core_utils.litellm_logging import Logging -from litellm.litellm_core_utils.token_counter import get_modified_max_tokens -from litellm.timeout import timeout -from litellm.utils import ( - EmbeddingResponse, - ImageResponse, - ModelResponse, - ModelResponseListIterator, - TextCompletionResponse, - TranscriptionResponse, - _calculate_retry_after, - _should_retry, - acreate, - check_valid_key, - client, - create_pretrained_tokenizer, - create_tokenizer, - decode, - encode, - exception_type, - get_api_base, - get_first_chars_messages, - get_litellm_params, - get_max_tokens, - get_model_info, - get_model_list, - get_optional_params, - get_provider_fields, - get_response_string, - get_supported_openai_params, - modify_integration, - register_model, - register_prompt_template, - supports_function_calling, - supports_parallel_function_calling, - supports_response_schema, - supports_system_messages, - supports_vision, - token_counter, - validate_environment, -) - -ALL_LITELLM_RESPONSE_TYPES = [ - ModelResponse, - EmbeddingResponse, - ImageResponse, - TranscriptionResponse, - TextCompletionResponse, -] - -from litellm.assistants.main import * -from litellm.batches.main import * -from litellm.budget_manager import BudgetManager -from litellm.cost_calculator import cost_per_token, response_cost_calculator -from litellm.exceptions import ( - LITELLM_EXCEPTION_TYPES, - APIConnectionError, - APIError, - APIResponseValidationError, - AuthenticationError, - BadRequestError, - BudgetExceededError, - ContentPolicyViolationError, - ContextWindowExceededError, - InternalServerError, - InvalidRequestError, - JSONSchemaValidationError, - MockException, - NotFoundError, - OpenAIError, - RateLimitError, - ServiceUnavailableError, - Timeout, - UnprocessableEntityError, - UnsupportedParamsError, -) -from litellm.files.main import * -from litellm.fine_tuning.main import * -from litellm.integrations import * -from litellm.llms.AI21.chat import AI21ChatConfig -from litellm.llms.AI21.completion import AI21Config -from litellm.llms.aleph_alpha import AlephAlphaConfig -from litellm.llms.anthropic.chat import AnthropicConfig -from litellm.llms.anthropic.completion import AnthropicTextConfig -from litellm.llms.AzureOpenAI.azure import ( - AzureOpenAIAssistantsAPIConfig, - AzureOpenAIConfig, - AzureOpenAIError, -) -from litellm.llms.bedrock.chat import ( - BEDROCK_CONVERSE_MODELS, - AmazonCohereChatConfig, - AmazonConverseConfig, - bedrock_tool_name_mappings, -) -from litellm.llms.bedrock.common_utils import ( - AmazonAI21Config, - AmazonAnthropicClaude3Config, - AmazonAnthropicConfig, - AmazonBedrockGlobalConfig, - AmazonCohereConfig, - AmazonLlamaConfig, - AmazonMistralConfig, - AmazonStabilityConfig, - AmazonTitanConfig, -) -from litellm.llms.bedrock.embed.amazon_titan_g1_transformation import ( - AmazonTitanG1Config, -) -from litellm.llms.bedrock.embed.amazon_titan_multimodal_transformation import ( - AmazonTitanMultimodalEmbeddingG1Config, -) -from litellm.llms.bedrock.embed.amazon_titan_v2_transformation import ( - AmazonTitanV2Config, -) -from litellm.llms.bedrock.embed.cohere_transformation import ( - BedrockCohereEmbeddingConfig, -) -from litellm.llms.cerebras.chat import CerebrasConfig -from litellm.llms.clarifai import ClarifaiConfig -from litellm.llms.cloudflare import CloudflareConfig -from litellm.llms.cohere.completion import CohereConfig -from litellm.llms.custom_llm import CustomLLM -from litellm.llms.databricks.chat import ( - DatabricksConfig, - DatabricksEmbeddingConfig, -) -from litellm.llms.fireworks_ai import FireworksAIConfig -from litellm.llms.gemini import GeminiConfig -from litellm.llms.huggingface_restapi import HuggingfaceConfig -from litellm.llms.maritalk import MaritTalkConfig -from litellm.llms.nlp_cloud import NLPCloudConfig -from litellm.llms.nvidia_nim import NvidiaNimConfig -from litellm.llms.ollama import OllamaConfig -from litellm.llms.ollama_chat import OllamaChatConfig -from litellm.llms.OpenAI.o1_reasoning import OpenAIO1Config -from litellm.llms.OpenAI.openai import ( - AzureAIStudioConfig, - DeepInfraConfig, - GroqConfig, - MistralConfig, - MistralEmbeddingConfig, - OpenAIConfig, - OpenAITextCompletionConfig, -) -from litellm.llms.palm import PalmConfig -from litellm.llms.petals import PetalsConfig -from litellm.llms.predibase import PredibaseConfig -from litellm.llms.replicate import ReplicateConfig -from litellm.llms.sagemaker.sagemaker import SagemakerConfig -from litellm.llms.text_completion_codestral import MistralTextCompletionConfig -from litellm.llms.together_ai import TogetherAIConfig -from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( - GoogleAIStudioGeminiConfig, - VertexAIConfig, - VertexGeminiConfig, -) -from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_anthropic import ( - VertexAIAnthropicConfig, -) -from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.ai21.transformation import ( - VertexAIAi21Config, -) -from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.llama3.transformation import ( - VertexAILlama3Config, -) -from litellm.llms.vertex_ai_and_google_ai_studio.vertex_embeddings.embedding_handler import ( - VertexAITextEmbeddingConfig, -) -from litellm.llms.volcengine import VolcEngineConfig -from litellm.llms.watsonx import IBMWatsonXAIConfig -from litellm.proxy.proxy_cli import run_server -from litellm.rerank_api.main import * -from litellm.router import Router -from litellm.scheduler import * - -### ADAPTERS ### -from litellm.types.adapter import AdapterItem -from litellm.types.utils import ImageObject - -from .main import * - -adapters: List[AdapterItem] = [] - -### CUSTOM LLMs ### -from litellm.types.llms.custom_llm import CustomLLMItem -from litellm.types.utils import GenericStreamingChunk - -custom_provider_map: List[CustomLLMItem] = [] -_custom_providers: List[ - str -] = [] # internal helper util, used to track names of custom providers +from .main import * # noqa diff --git a/notdiamond/toolkit/litellm/litellm_notdiamond.py b/notdiamond/toolkit/litellm/litellm_notdiamond.py index c44ab627..3e3decd0 100644 --- a/notdiamond/toolkit/litellm/litellm_notdiamond.py +++ b/notdiamond/toolkit/litellm/litellm_notdiamond.py @@ -26,6 +26,10 @@ "openai/gpt-4-1106-preview": "gpt-4-1106-preview", "openai/gpt-4o-mini": "gpt-4o-mini", "openai/gpt-4o-mini-2024-07-18": "gpt-4o-mini-2024-07-18", + "openai/o1-preview-2024-09-12": "o1-preview-2024-09-12", + "openai/o1-preview": "o1-preview", + "openai/o1-mini-2024-09-12": "o1-mini-2024-09-12", + "openai/o1-mini": "o1-mini", # anthropic "anthropic/claude-2.1": "claude-2.1", "anthropic/claude-3-opus-20240229": "claude-3-opus-20240229", @@ -76,7 +80,7 @@ def __init__( self, status_code, message, - url="https://not-diamond-server.onrender.com/v2/optimizer/modelSelect", + url="https://api.notdiamond.ai", ): self.status_code = status_code self.message = message diff --git a/notdiamond/toolkit/litellm/main.py b/notdiamond/toolkit/litellm/main.py index aeb57643..61f76d5b 100644 --- a/notdiamond/toolkit/litellm/main.py +++ b/notdiamond/toolkit/litellm/main.py @@ -155,6 +155,33 @@ from typing_extensions import overload encoding = tiktoken.get_encoding("cl100k_base") +from litellm.main import ( + AsyncCompletions, + Chat, + Completions, + LiteLLM, + _async_streaming, + aadapter_completion, + adapter_completion, + aembedding, + ahealth_check, + aimage_generation, + amoderation, + aspeech, + atext_completion, + atranscription, + batch_completion_models, + config_completion, + embedding, + image_generation, + moderation, + print_verbose, + speech, + stream_chunk_builder, + stream_chunk_builder_text_completion, + text_completion, + transcription, +) from litellm.types.router import LiteLLM_Params from litellm.utils import ( Choices, @@ -200,73 +227,6 @@ sagemaker_llm = SagemakerLLM() -class LiteLLM: - def __init__( - self, - *, - api_key=None, - organization: Optional[str] = None, - base_url: Optional[str] = None, - timeout: Optional[float] = 600, - max_retries: Optional[int] = litellm.num_retries, - default_headers: Optional[Mapping[str, str]] = None, - ): - self.params = locals() - self.chat = Chat(self.params, router_obj=None) - - -class Chat: - def __init__(self, params, router_obj: Optional[Any]): - self.params = params - if self.params.get("acompletion", False) == True: - self.params.pop("acompletion") - self.completions: Union[ - AsyncCompletions, Completions - ] = AsyncCompletions(self.params, router_obj=router_obj) - else: - self.completions = Completions(self.params, router_obj=router_obj) - - -class Completions: - def __init__(self, params, router_obj: Optional[Any]): - self.params = params - self.router_obj = router_obj - - def create(self, messages, model=None, **kwargs): - for k, v in kwargs.items(): - self.params[k] = v - model = model or self.params.get("model") - if self.router_obj is not None: - response = self.router_obj.completion( - model=model, messages=messages, **self.params - ) - else: - response = completion( - model=model, messages=messages, **self.params - ) - return response - - -class AsyncCompletions: - def __init__(self, params, router_obj: Optional[Any]): - self.params = params - self.router_obj = router_obj - - async def create(self, messages, model=None, **kwargs): - for k, v in kwargs.items(): - self.params[k] = v - model = model or self.params.get("model") - if self.router_obj is not None: - response = await self.router_obj.acompletion( - model=model, messages=messages, **self.params - ) - else: - response = await acompletion( - model=model, messages=messages, **self.params - ) - return response - - def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]): api_key = dynamic_api_key or litellm.api_key # openai @@ -894,193 +854,6 @@ async def acompletion( ) -async def _async_streaming(response, model, custom_llm_provider, args): - try: - print_verbose(f"received response in _async_streaming: {response}") - if asyncio.iscoroutine(response): - response = await response - async for line in response: - print_verbose(f"line in async streaming: {line}") - yield line - except Exception as e: - custom_llm_provider = custom_llm_provider or "openai" - raise exception_type( - model=model, - custom_llm_provider=custom_llm_provider, - original_exception=e, - ) - - -def mock_completion( - model: str, - messages: List, - stream: Optional[bool] = False, - n: Optional[int] = None, - mock_response: Union[str, Exception, dict] = "This is a mock request", - mock_tool_calls: Optional[List] = None, - logging=None, - custom_llm_provider=None, - **kwargs, -): - """ - Generate a mock completion response for testing or debugging purposes. - - This is a helper function that simulates the response structure of the OpenAI completion API. - - Parameters: - model (str): The name of the language model for which the mock response is generated. - messages (List): A list of message objects representing the conversation context. - stream (bool, optional): If True, returns a mock streaming response (default is False). - mock_response (str, optional): The content of the mock response (default is "This is a mock request"). - **kwargs: Additional keyword arguments that can be used but are not required. - - Returns: - litellm.ModelResponse: A ModelResponse simulating a completion response with the specified model, messages, and mock response. - - Raises: - Exception: If an error occurs during the generation of the mock completion response. - - Note: - - This function is intended for testing or debugging purposes to generate mock completion responses. - - If 'stream' is True, it returns a response that mimics the behavior of a streaming completion. - """ - try: - ## LOGGING - if logging is not None: - logging.pre_call( - input=messages, - api_key="mock-key", - ) - if isinstance(mock_response, Exception): - if isinstance(mock_response, openai.APIError): - raise mock_response - raise litellm.MockException( - status_code=getattr(mock_response, "status_code", 500), # type: ignore - message=getattr(mock_response, "text", str(mock_response)), - llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore - model=model, # type: ignore - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ) - elif ( - isinstance(mock_response, str) - and mock_response == "litellm.RateLimitError" - ): - raise litellm.RateLimitError( - message="this is a mock rate limit error", - llm_provider=getattr(mock_response, "llm_provider", custom_llm_provider or "openai"), # type: ignore - model=model, - ) - elif isinstance(mock_response, str) and mock_response.startswith( - "Exception: content_filter_policy" - ): - raise litellm.MockException( - status_code=400, - message=mock_response, - llm_provider="azure", - model=model, # type: ignore - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ) - elif isinstance(mock_response, str) and mock_response.startswith( - "Exception: mock_streaming_error" - ): - mock_response = litellm.MockException( - message="This is a mock error raised mid-stream", - llm_provider="anthropic", - model=model, - status_code=529, - ) - time_delay = kwargs.get("mock_delay", None) - if time_delay is not None: - time.sleep(time_delay) - - if isinstance(mock_response, dict): - return ModelResponse(**mock_response) - - model_response = ModelResponse(stream=stream) - if stream is True: - # don't try to access stream object, - if kwargs.get("acompletion", False) is True: - return CustomStreamWrapper( - completion_stream=async_mock_completion_streaming_obj( - model_response, - mock_response=mock_response, - model=model, - n=n, - ), - model=model, - custom_llm_provider="openai", - logging_obj=logging, - ) - return CustomStreamWrapper( - completion_stream=mock_completion_streaming_obj( - model_response, - mock_response=mock_response, - model=model, - n=n, - ), - model=model, - custom_llm_provider="openai", - logging_obj=logging, - ) - if isinstance(mock_response, litellm.MockException): - raise mock_response - if n is None: - model_response.choices[0].message.content = mock_response # type: ignore - else: - _all_choices = [] - for i in range(n): - _choice = litellm.utils.Choices( - index=i, - message=litellm.utils.Message( - content=mock_response, role="assistant" - ), - ) - _all_choices.append(_choice) - model_response.choices = _all_choices # type: ignore - model_response.created = int(time.time()) - model_response.model = model - - if mock_tool_calls: - model_response.choices[0].message.tool_calls = [ # type: ignore - ChatCompletionMessageToolCall(**tool_call) - for tool_call in mock_tool_calls - ] - - setattr( - model_response, - "usage", - Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30), - ) - - try: - _, custom_llm_provider, _, _ = litellm.utils.get_llm_provider( - model=model - ) - model_response._hidden_params[ - "custom_llm_provider" - ] = custom_llm_provider - except Exception: - # dont let setting a hidden param block a mock_respose - pass - - if logging is not None: - logging.post_call( - input=messages, - api_key="my-secret-key", - original_response="my-original-response", - ) - return model_response - - except Exception as e: - if isinstance(e, openai.APIError): - raise e - raise Exception("Mock completion response failed") - - @client def completion( model: str, @@ -3423,2823 +3196,3 @@ def completion( completion_kwargs=args, extra_kwargs=kwargs, ) - - -def completion_with_retries(*args, **kwargs): - """ - Executes a litellm.completion() with 3 retries - """ - try: - import tenacity - except Exception as e: - raise Exception( - f"tenacity import failed please run `pip install tenacity`. Error{e}" - ) - - num_retries = kwargs.pop("num_retries", 3) - retry_strategy = kwargs.pop("retry_strategy", "constant_retry") - original_function = kwargs.pop("original_function", completion) - if retry_strategy == "constant_retry": - retryer = tenacity.Retrying( - stop=tenacity.stop_after_attempt(num_retries), reraise=True - ) - elif retry_strategy == "exponential_backoff_retry": - retryer = tenacity.Retrying( - wait=tenacity.wait_exponential(multiplier=1, max=10), - stop=tenacity.stop_after_attempt(num_retries), - reraise=True, - ) - return retryer(original_function, *args, **kwargs) - - -async def acompletion_with_retries(*args, **kwargs): - """ - [DEPRECATED]. Use 'acompletion' or router.acompletion instead! - Executes a litellm.completion() with 3 retries - """ - try: - import tenacity - except Exception as e: - raise Exception( - f"tenacity import failed please run `pip install tenacity`. Error{e}" - ) - - num_retries = kwargs.pop("num_retries", 3) - retry_strategy = kwargs.pop("retry_strategy", "constant_retry") - original_function = kwargs.pop("original_function", completion) - if retry_strategy == "constant_retry": - retryer = tenacity.Retrying( - stop=tenacity.stop_after_attempt(num_retries), reraise=True - ) - elif retry_strategy == "exponential_backoff_retry": - retryer = tenacity.Retrying( - wait=tenacity.wait_exponential(multiplier=1, max=10), - stop=tenacity.stop_after_attempt(num_retries), - reraise=True, - ) - return await retryer(original_function, *args, **kwargs) - - -def batch_completion( - model: str, - # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create - messages: List = [], - functions: Optional[List] = None, - function_call: Optional[str] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - n: Optional[int] = None, - stream: Optional[bool] = None, - stop=None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[dict] = None, - user: Optional[str] = None, - deployment_id=None, - request_timeout: Optional[int] = None, - timeout: Optional[int] = 600, - # Optional liteLLM function params - **kwargs, -): - """ - Batch litellm.completion function for a given model. - - Args: - model (str): The model to use for generating completions. - messages (List, optional): List of messages to use as input for generating completions. Defaults to []. - functions (List, optional): List of functions to use as input for generating completions. Defaults to []. - function_call (str, optional): The function call to use as input for generating completions. Defaults to "". - temperature (float, optional): The temperature parameter for generating completions. Defaults to None. - top_p (float, optional): The top-p parameter for generating completions. Defaults to None. - n (int, optional): The number of completions to generate. Defaults to None. - stream (bool, optional): Whether to stream completions or not. Defaults to None. - stop (optional): The stop parameter for generating completions. Defaults to None. - max_tokens (float, optional): The maximum number of tokens to generate. Defaults to None. - presence_penalty (float, optional): The presence penalty for generating completions. Defaults to None. - frequency_penalty (float, optional): The frequency penalty for generating completions. Defaults to None. - logit_bias (dict, optional): The logit bias for generating completions. Defaults to {}. - user (str, optional): The user string for generating completions. Defaults to "". - deployment_id (optional): The deployment ID for generating completions. Defaults to None. - request_timeout (int, optional): The request timeout for generating completions. Defaults to None. - - Returns: - list: A list of completion results. - """ - args = locals() - - batch_messages = messages - completions = [] - model = model - custom_llm_provider = None - if model.split("/", 1)[0] in provider_list: - custom_llm_provider = model.split("/", 1)[0] - model = model.split("/", 1)[1] - if custom_llm_provider == "vllm": - optional_params = get_optional_params( - functions=functions, - function_call=function_call, - temperature=temperature, - top_p=top_p, - n=n, - stream=stream, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - user=user, - # params to identify the model - model=model, - custom_llm_provider=custom_llm_provider, - ) - results = vllm.batch_completions( - model=model, - messages=batch_messages, - custom_prompt_dict=litellm.custom_prompt_dict, - optional_params=optional_params, - ) - # all non VLLM models for batch completion models - else: - - def chunks(lst, n): - """Yield successive n-sized chunks from lst.""" - for i in range(0, len(lst), n): - yield lst[i : i + n] - - with ThreadPoolExecutor(max_workers=100) as executor: - for sub_batch in chunks(batch_messages, 100): - for message_list in sub_batch: - kwargs_modified = args.copy() - kwargs_modified["messages"] = message_list - original_kwargs = {} - if "kwargs" in kwargs_modified: - original_kwargs = kwargs_modified.pop("kwargs") - future = executor.submit( - completion, **kwargs_modified, **original_kwargs - ) - completions.append(future) - - # Retrieve the results from the futures - # results = [future.result() for future in completions] - # return exceptions if any - results = [] - for future in completions: - try: - results.append(future.result()) - except Exception as exc: - results.append(exc) - - return results - - -# send one request to multiple models -# return as soon as one of the llms responds -def batch_completion_models(*args, **kwargs): - """ - Send a request to multiple language models concurrently and return the response - as soon as one of the models responds. - - Args: - *args: Variable-length positional arguments passed to the completion function. - **kwargs: Additional keyword arguments: - - models (str or list of str): The language models to send requests to. - - Other keyword arguments to be passed to the completion function. - - Returns: - str or None: The response from one of the language models, or None if no response is received. - - Note: - This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models. - It sends requests concurrently and returns the response from the first model that responds. - """ - import concurrent - - if "model" in kwargs: - kwargs.pop("model") - if "models" in kwargs: - models = kwargs["models"] - kwargs.pop("models") - futures = {} - with concurrent.futures.ThreadPoolExecutor( - max_workers=len(models) - ) as executor: - for model in models: - futures[model] = executor.submit( - completion, *args, model=model, **kwargs - ) - - for model, future in sorted( - futures.items(), key=lambda x: models.index(x[0]) - ): - if future.result() is not None: - return future.result() - elif "deployments" in kwargs: - deployments = kwargs["deployments"] - kwargs.pop("deployments") - kwargs.pop("model_list") - nested_kwargs = kwargs.pop("kwargs", {}) - futures = {} - with concurrent.futures.ThreadPoolExecutor( - max_workers=len(deployments) - ) as executor: - for deployment in deployments: - for key in kwargs.keys(): - if ( - key not in deployment - ): # don't override deployment values e.g. model name, api base, etc. - deployment[key] = kwargs[key] - kwargs = {**deployment, **nested_kwargs} - futures[deployment["model"]] = executor.submit( - completion, **kwargs - ) - - while futures: - # wait for the first returned future - print_verbose("\n\n waiting for next result\n\n") - done, _ = concurrent.futures.wait( - futures.values(), - return_when=concurrent.futures.FIRST_COMPLETED, - ) - print_verbose(f"done list\n{done}") - for future in done: - try: - result = future.result() - return result - except Exception as e: - # if model 1 fails, continue with response from model 2, model3 - print_verbose( - f"\n\ngot an exception, ignoring, removing from futures" - ) - print_verbose(futures) - new_futures = {} - for key, value in futures.items(): - if future == value: - print_verbose(f"removing key{key}") - continue - else: - new_futures[key] = value - futures = new_futures - print_verbose(f"new futures{futures}") - continue - - print_verbose("\n\ndone looping through futures\n\n") - print_verbose(futures) - - return None # If no response is received from any model - - -def batch_completion_models_all_responses(*args, **kwargs): - """ - Send a request to multiple language models concurrently and return a list of responses - from all models that respond. - - Args: - *args: Variable-length positional arguments passed to the completion function. - **kwargs: Additional keyword arguments: - - models (str or list of str): The language models to send requests to. - - Other keyword arguments to be passed to the completion function. - - Returns: - list: A list of responses from the language models that responded. - - Note: - This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models. - It sends requests concurrently and collects responses from all models that respond. - """ - import concurrent.futures - - # ANSI escape codes for colored output - GREEN = "\033[92m" - RED = "\033[91m" - RESET = "\033[0m" - - if "model" in kwargs: - kwargs.pop("model") - if "models" in kwargs: - models = kwargs["models"] - kwargs.pop("models") - - responses = [] - - with concurrent.futures.ThreadPoolExecutor( - max_workers=len(models) - ) as executor: - for idx, model in enumerate(models): - future = executor.submit(completion, *args, model=model, **kwargs) - if future.result() is not None: - responses.append(future.result()) - - return responses - - -### EMBEDDING ENDPOINTS #################### -@client -async def aembedding(*args, **kwargs) -> EmbeddingResponse: - """ - Asynchronously calls the `embedding` function with the given arguments and keyword arguments. - - Parameters: - - `args` (tuple): Positional arguments to be passed to the `embedding` function. - - `kwargs` (dict): Keyword arguments to be passed to the `embedding` function. - - Returns: - - `response` (Any): The response returned by the `embedding` function. - """ - loop = asyncio.get_event_loop() - model = args[0] if len(args) > 0 else kwargs["model"] - ### PASS ARGS TO Embedding ### - kwargs["aembedding"] = True - custom_llm_provider = None - try: - # Use a partial function to pass your keyword arguments - func = partial(embedding, *args, **kwargs) - - # Add the context to the function - ctx = contextvars.copy_context() - func_with_context = partial(ctx.run, func) - - _, custom_llm_provider, _, _ = get_llm_provider( - model=model, api_base=kwargs.get("api_base", None) - ) - - if ( - custom_llm_provider == "openai" - or custom_llm_provider == "azure" - or custom_llm_provider == "xinference" - or custom_llm_provider == "voyage" - or custom_llm_provider == "mistral" - or custom_llm_provider == "custom_openai" - or custom_llm_provider == "triton" - or custom_llm_provider == "anyscale" - or custom_llm_provider == "openrouter" - or custom_llm_provider == "deepinfra" - or custom_llm_provider == "perplexity" - or custom_llm_provider == "groq" - or custom_llm_provider == "nvidia_nim" - or custom_llm_provider == "cerebras" - or custom_llm_provider == "ai21_chat" - or custom_llm_provider == "volcengine" - or custom_llm_provider == "deepseek" - or custom_llm_provider == "fireworks_ai" - or custom_llm_provider == "ollama" - or custom_llm_provider == "vertex_ai" - or custom_llm_provider == "gemini" - or custom_llm_provider == "databricks" - or custom_llm_provider == "watsonx" - or custom_llm_provider == "cohere" - or custom_llm_provider == "huggingface" - or custom_llm_provider == "bedrock" - ): # currently implemented aiohttp calls for just azure and openai, soon all. - # Await normally - init_response = await loop.run_in_executor(None, func_with_context) - if isinstance(init_response, dict): - response = EmbeddingResponse(**init_response) - elif isinstance( - init_response, EmbeddingResponse - ): ## CACHING SCENARIO - response = init_response - elif asyncio.iscoroutine(init_response): - response = await init_response - else: - # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) - if response is not None and hasattr(response, "_hidden_params"): - response._hidden_params[ - "custom_llm_provider" - ] = custom_llm_provider - return response - except Exception as e: - custom_llm_provider = custom_llm_provider or "openai" - raise exception_type( - model=model, - custom_llm_provider=custom_llm_provider, - original_exception=e, - completion_kwargs=args, - extra_kwargs=kwargs, - ) - - -@client -def embedding( - model, - input=[], - # Optional params - dimensions: Optional[int] = None, - timeout=600, # default to 10 minutes - # set api_base, api_version, api_key - api_base: Optional[str] = None, - api_version: Optional[str] = None, - api_key: Optional[str] = None, - api_type: Optional[str] = None, - caching: bool = False, - user: Optional[str] = None, - custom_llm_provider=None, - litellm_call_id=None, - litellm_logging_obj=None, - logger_fn=None, - **kwargs, -) -> EmbeddingResponse: - """ - Embedding function that calls an API to generate embeddings for the given input. - - Parameters: - - model: The embedding model to use. - - input: The input for which embeddings are to be generated. - - dimensions: The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3 and later models. - - timeout: The timeout value for the API call, default 10 mins - - litellm_call_id: The call ID for litellm logging. - - litellm_logging_obj: The litellm logging object. - - logger_fn: The logger function. - - api_base: Optional. The base URL for the API. - - api_version: Optional. The version of the API. - - api_key: Optional. The API key to use. - - api_type: Optional. The type of the API. - - caching: A boolean indicating whether to enable caching. - - custom_llm_provider: The custom llm provider. - - Returns: - - response: The response received from the API call. - - Raises: - - exception_type: If an exception occurs during the API call. - """ - azure = kwargs.get("azure", None) - client = kwargs.pop("client", None) - rpm = kwargs.pop("rpm", None) - tpm = kwargs.pop("tpm", None) - cooldown_time = kwargs.get("cooldown_time", None) - max_parallel_requests = kwargs.pop("max_parallel_requests", None) - model_info = kwargs.get("model_info", None) - metadata = kwargs.get("metadata", None) - encoding_format = kwargs.get("encoding_format", None) - proxy_server_request = kwargs.get("proxy_server_request", None) - aembedding = kwargs.get("aembedding", None) - extra_headers = kwargs.get("extra_headers", None) - ### CUSTOM MODEL COST ### - input_cost_per_token = kwargs.get("input_cost_per_token", None) - output_cost_per_token = kwargs.get("output_cost_per_token", None) - input_cost_per_second = kwargs.get("input_cost_per_second", None) - output_cost_per_second = kwargs.get("output_cost_per_second", None) - openai_params = [ - "user", - "dimensions", - "request_timeout", - "api_base", - "api_version", - "api_key", - "deployment_id", - "organization", - "base_url", - "default_headers", - "timeout", - "max_retries", - "encoding_format", - ] - litellm_params = [ - "metadata", - "aembedding", - "caching", - "mock_response", - "api_key", - "api_version", - "api_base", - "force_timeout", - "logger_fn", - "verbose", - "custom_llm_provider", - "litellm_logging_obj", - "litellm_call_id", - "use_client", - "id", - "fallbacks", - "azure", - "headers", - "model_list", - "num_retries", - "context_window_fallback_dict", - "retry_policy", - "roles", - "final_prompt_value", - "bos_token", - "eos_token", - "request_timeout", - "complete_response", - "self", - "client", - "rpm", - "tpm", - "max_parallel_requests", - "input_cost_per_token", - "output_cost_per_token", - "input_cost_per_second", - "output_cost_per_second", - "hf_model_name", - "proxy_server_request", - "model_info", - "preset_cache_key", - "caching_groups", - "ttl", - "cache", - "no-log", - "region_name", - "allowed_model_region", - "model_config", - "cooldown_time", - "tags", - "azure_ad_token_provider", - "tenant_id", - "client_id", - "client_secret", - "extra_headers", - ] - default_params = openai_params + litellm_params - non_default_params = { - k: v for k, v in kwargs.items() if k not in default_params - } # model-specific params - pass them straight to the model/provider - - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( - model=model, - custom_llm_provider=custom_llm_provider, - api_base=api_base, - api_key=api_key, - ) - optional_params = get_optional_params_embeddings( - model=model, - user=user, - dimensions=dimensions, - encoding_format=encoding_format, - custom_llm_provider=custom_llm_provider, - **non_default_params, - ) - ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### - if input_cost_per_token is not None and output_cost_per_token is not None: - litellm.register_model( - { - model: { - "input_cost_per_token": input_cost_per_token, - "output_cost_per_token": output_cost_per_token, - "litellm_provider": custom_llm_provider, - } - } - ) - if ( - input_cost_per_second is not None - ): # time based pricing just needs cost in place - output_cost_per_second = output_cost_per_second or 0.0 - litellm.register_model( - { - model: { - "input_cost_per_second": input_cost_per_second, - "output_cost_per_second": output_cost_per_second, - "litellm_provider": custom_llm_provider, - } - } - ) - try: - response = None - logging: Logging = litellm_logging_obj # type: ignore - logging.update_environment_variables( - model=model, - user=user, - optional_params=optional_params, - litellm_params={ - "timeout": timeout, - "azure": azure, - "litellm_call_id": litellm_call_id, - "logger_fn": logger_fn, - "proxy_server_request": proxy_server_request, - "model_info": model_info, - "metadata": metadata, - "aembedding": aembedding, - "preset_cache_key": None, - "stream_response": {}, - "cooldown_time": cooldown_time, - }, - ) - if azure is True or custom_llm_provider == "azure": - # azure configs - api_type = get_secret("AZURE_API_TYPE") or "azure" - - api_base = ( - api_base or litellm.api_base or get_secret("AZURE_API_BASE") - ) - - api_version = ( - api_version - or litellm.api_version - or get_secret("AZURE_API_VERSION") - or litellm.AZURE_DEFAULT_API_VERSION - ) - - azure_ad_token = optional_params.pop( - "azure_ad_token", None - ) or get_secret("AZURE_AD_TOKEN") - - api_key = ( - api_key - or litellm.api_key - or litellm.azure_key - or get_secret("AZURE_API_KEY") - ) - ## EMBEDDING CALL - response = azure_chat_completions.embedding( - model=model, - input=input, - api_base=api_base, - api_key=api_key, - api_version=api_version, - azure_ad_token=azure_ad_token, - logging_obj=logging, - timeout=timeout, - model_response=EmbeddingResponse(), - optional_params=optional_params, - client=client, - aembedding=aembedding, - ) - elif ( - model in litellm.open_ai_embedding_models - or custom_llm_provider == "openai" - ): - api_base = ( - api_base - or litellm.api_base - or get_secret("OPENAI_API_BASE") - or "https://api.openai.com/v1" - ) - openai.organization = ( - litellm.organization - or get_secret("OPENAI_ORGANIZATION") - or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 - ) - # set API KEY - api_key = ( - api_key - or litellm.api_key - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) - api_type = "openai" - api_version = None - - ## EMBEDDING CALL - response = openai_chat_completions.embedding( - model=model, - input=input, - api_base=api_base, - api_key=api_key, - logging_obj=logging, - timeout=timeout, - model_response=EmbeddingResponse(), - optional_params=optional_params, - client=client, - aembedding=aembedding, - ) - elif custom_llm_provider == "databricks": - api_base = ( - api_base - or litellm.api_base - or get_secret("DATABRICKS_API_BASE") - ) # type: ignore - - # set API KEY - api_key = ( - api_key - or litellm.api_key - or litellm.databricks_key - or get_secret("DATABRICKS_API_KEY") - ) # type: ignore - - ## EMBEDDING CALL - response = databricks_chat_completions.embedding( - model=model, - input=input, - api_base=api_base, - api_key=api_key, - logging_obj=logging, - timeout=timeout, - model_response=EmbeddingResponse(), - optional_params=optional_params, - client=client, - aembedding=aembedding, - ) - elif ( - custom_llm_provider == "cohere" - or custom_llm_provider == "cohere_chat" - ): - cohere_key = ( - api_key - or litellm.cohere_key - or get_secret("COHERE_API_KEY") - or get_secret("CO_API_KEY") - or litellm.api_key - ) - - if extra_headers is not None and isinstance(extra_headers, dict): - headers = extra_headers - else: - headers = {} - response = cohere_embed.embedding( - model=model, - input=input, - optional_params=optional_params, - encoding=encoding, - api_key=cohere_key, # type: ignore - headers=headers, - logging_obj=logging, - model_response=EmbeddingResponse(), - aembedding=aembedding, - timeout=timeout, - client=client, - ) - elif custom_llm_provider == "huggingface": - api_key = ( - api_key - or litellm.huggingface_key - or get_secret("HUGGINGFACE_API_KEY") - or litellm.api_key - ) # type: ignore - response = huggingface.embedding( - model=model, - input=input, - encoding=encoding, # type: ignore - api_key=api_key, - api_base=api_base, - logging_obj=logging, - model_response=EmbeddingResponse(), - optional_params=optional_params, - client=client, - aembedding=aembedding, - ) - elif custom_llm_provider == "bedrock": - if isinstance(input, str): - transformed_input = [input] - else: - transformed_input = input - response = bedrock_embedding.embeddings( - model=model, - input=transformed_input, - encoding=encoding, - logging_obj=logging, - optional_params=optional_params, - model_response=EmbeddingResponse(), - client=client, - timeout=timeout, - aembedding=aembedding, - litellm_params=litellm_params, - api_base=api_base, - print_verbose=print_verbose, - extra_headers=extra_headers, - ) - elif custom_llm_provider == "triton": - if api_base is None: - raise ValueError( - "api_base is required for triton. Please pass `api_base`" - ) - response = triton_chat_completions.embedding( - model=model, - input=input, - api_base=api_base, - api_key=api_key, - logging_obj=logging, - timeout=timeout, - model_response=EmbeddingResponse(), - optional_params=optional_params, - client=client, - aembedding=aembedding, - ) - elif custom_llm_provider == "gemini": - gemini_api_key = ( - api_key or get_secret("GEMINI_API_KEY") or litellm.api_key - ) - - response = google_batch_embeddings.batch_embeddings( # type: ignore - model=model, - input=input, - encoding=encoding, - logging_obj=logging, - optional_params=optional_params, - model_response=EmbeddingResponse(), - vertex_project=None, - vertex_location=None, - vertex_credentials=None, - aembedding=aembedding, - print_verbose=print_verbose, - custom_llm_provider="gemini", - api_key=gemini_api_key, - ) - - elif custom_llm_provider == "vertex_ai": - vertex_ai_project = ( - optional_params.pop("vertex_project", None) - or optional_params.pop("vertex_ai_project", None) - or litellm.vertex_project - or get_secret("VERTEXAI_PROJECT") - or get_secret("VERTEX_PROJECT") - ) - vertex_ai_location = ( - optional_params.pop("vertex_location", None) - or optional_params.pop("vertex_ai_location", None) - or litellm.vertex_location - or get_secret("VERTEXAI_LOCATION") - or get_secret("VERTEX_LOCATION") - ) - vertex_credentials = ( - optional_params.pop("vertex_credentials", None) - or optional_params.pop("vertex_ai_credentials", None) - or get_secret("VERTEXAI_CREDENTIALS") - or get_secret("VERTEX_CREDENTIALS") - ) - - if ( - "image" in optional_params - or "video" in optional_params - or model - in vertex_multimodal_embedding.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS - ): - # multimodal embedding is supported on vertex httpx - response = vertex_multimodal_embedding.multimodal_embedding( - model=model, - input=input, - encoding=encoding, - logging_obj=logging, - optional_params=optional_params, - model_response=EmbeddingResponse(), - vertex_project=vertex_ai_project, - vertex_location=vertex_ai_location, - vertex_credentials=vertex_credentials, - aembedding=aembedding, - print_verbose=print_verbose, - custom_llm_provider="vertex_ai", - ) - else: - response = vertex_ai_embedding_handler.embedding( - model=model, - input=input, - encoding=encoding, - logging_obj=logging, - optional_params=optional_params, - model_response=EmbeddingResponse(), - vertex_project=vertex_ai_project, - vertex_location=vertex_ai_location, - vertex_credentials=vertex_credentials, - aembedding=aembedding, - print_verbose=print_verbose, - ) - elif custom_llm_provider == "oobabooga": - response = oobabooga.embedding( - model=model, - input=input, - encoding=encoding, - api_base=api_base, - logging_obj=logging, - optional_params=optional_params, - model_response=EmbeddingResponse(), - ) - elif custom_llm_provider == "ollama": - api_base = ( - litellm.api_base - or api_base - or get_secret("OLLAMA_API_BASE") - or "http://localhost:11434" - ) # type: ignore - if isinstance(input, str): - input = [input] - if not all(isinstance(item, str) for item in input): - raise litellm.BadRequestError( - message=f"Invalid input for ollama embeddings. input={input}", - model=model, # type: ignore - llm_provider="ollama", # type: ignore - ) - ollama_embeddings_fn = ( - ollama.ollama_aembeddings - if aembedding is True - else ollama.ollama_embeddings - ) - response = ollama_embeddings_fn( # type: ignore - api_base=api_base, - model=model, - prompts=input, - encoding=encoding, - logging_obj=logging, - optional_params=optional_params, - model_response=EmbeddingResponse(), - ) - elif custom_llm_provider == "sagemaker": - response = sagemaker_llm.embedding( - model=model, - input=input, - encoding=encoding, - logging_obj=logging, - optional_params=optional_params, - model_response=EmbeddingResponse(), - print_verbose=print_verbose, - ) - elif custom_llm_provider == "mistral": - api_key = ( - api_key or litellm.api_key or get_secret("MISTRAL_API_KEY") - ) - response = openai_chat_completions.embedding( - model=model, - input=input, - api_base=api_base, - api_key=api_key, - logging_obj=logging, - timeout=timeout, - model_response=EmbeddingResponse(), - optional_params=optional_params, - client=client, - aembedding=aembedding, - ) - elif custom_llm_provider == "voyage": - api_key = ( - api_key or litellm.api_key or get_secret("VOYAGE_API_KEY") - ) - response = openai_chat_completions.embedding( - model=model, - input=input, - api_base=api_base, - api_key=api_key, - logging_obj=logging, - timeout=timeout, - model_response=EmbeddingResponse(), - optional_params=optional_params, - client=client, - aembedding=aembedding, - ) - elif custom_llm_provider == "xinference": - api_key = ( - api_key - or litellm.api_key - or get_secret("XINFERENCE_API_KEY") - or "stub-xinference-key" - ) # xinference does not need an api key, pass a stub key if user did not set one - api_base = ( - api_base - or litellm.api_base - or get_secret("XINFERENCE_API_BASE") - or "http://127.0.0.1:9997/v1" - ) - response = openai_chat_completions.embedding( - model=model, - input=input, - api_base=api_base, - api_key=api_key, - logging_obj=logging, - timeout=timeout, - model_response=EmbeddingResponse(), - optional_params=optional_params, - client=client, - aembedding=aembedding, - ) - elif custom_llm_provider == "watsonx": - response = watsonxai.embedding( - model=model, - input=input, - encoding=encoding, - logging_obj=logging, - optional_params=optional_params, - model_response=EmbeddingResponse(), - aembedding=aembedding, - ) - else: - args = locals() - raise ValueError( - f"No valid embedding model args passed in - {args}" - ) - if response is not None and hasattr(response, "_hidden_params"): - response._hidden_params[ - "custom_llm_provider" - ] = custom_llm_provider - return response - except Exception as e: - ## LOGGING - logging.post_call( - input=input, - api_key=api_key, - original_response=str(e), - ) - ## Map to OpenAI Exception - raise exception_type( - model=model, - original_exception=e, - custom_llm_provider=custom_llm_provider, - extra_kwargs=kwargs, - ) - - -###### Text Completion ################ -@client -async def atext_completion( - *args, **kwargs -) -> Union[TextCompletionResponse, TextCompletionStreamWrapper]: - """ - Implemented to handle async streaming for the text completion endpoint - """ - loop = asyncio.get_event_loop() - model = args[0] if len(args) > 0 else kwargs["model"] - ### PASS ARGS TO COMPLETION ### - kwargs["acompletion"] = True - custom_llm_provider = None - try: - # Use a partial function to pass your keyword arguments - func = partial(text_completion, *args, **kwargs) - - # Add the context to the function - ctx = contextvars.copy_context() - func_with_context = partial(ctx.run, func) - - _, custom_llm_provider, _, _ = get_llm_provider( - model=model, api_base=kwargs.get("api_base", None) - ) - - if ( - custom_llm_provider == "openai" - or custom_llm_provider == "azure" - or custom_llm_provider == "azure_text" - or custom_llm_provider == "custom_openai" - or custom_llm_provider == "anyscale" - or custom_llm_provider == "mistral" - or custom_llm_provider == "openrouter" - or custom_llm_provider == "deepinfra" - or custom_llm_provider == "perplexity" - or custom_llm_provider == "groq" - or custom_llm_provider == "nvidia_nim" - or custom_llm_provider == "cerebras" - or custom_llm_provider == "ai21_chat" - or custom_llm_provider == "volcengine" - or custom_llm_provider == "text-completion-codestral" - or custom_llm_provider == "deepseek" - or custom_llm_provider == "fireworks_ai" - or custom_llm_provider == "text-completion-openai" - or custom_llm_provider == "huggingface" - or custom_llm_provider == "ollama" - or custom_llm_provider == "vertex_ai" - or custom_llm_provider in litellm.openai_compatible_providers - ): # currently implemented aiohttp calls for just azure and openai, soon all. - # Await normally - response = await loop.run_in_executor(None, func_with_context) - if asyncio.iscoroutine(response): - response = await response - else: - # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) - if kwargs.get("stream", False) is True: # return an async generator - return TextCompletionStreamWrapper( - completion_stream=_async_streaming( - response=response, - model=model, - custom_llm_provider=custom_llm_provider, - args=args, - ), - model=model, - custom_llm_provider=custom_llm_provider, - ) - else: - transformed_logprobs = None - # only supported for TGI models - try: - raw_response = response._hidden_params.get( - "original_response", None - ) - transformed_logprobs = litellm.utils.transform_logprobs( - raw_response - ) - except Exception as e: - print_verbose(f"LiteLLM non blocking exception: {e}") - - ## TRANSLATE CHAT TO TEXT FORMAT ## - if isinstance(response, TextCompletionResponse): - return response - elif asyncio.iscoroutine(response): - response = await response - - text_completion_response = TextCompletionResponse() - text_completion_response["id"] = response.get("id", None) - text_completion_response["object"] = "text_completion" - text_completion_response["created"] = response.get("created", None) - text_completion_response["model"] = response.get("model", None) - text_choices = TextChoices() - text_choices["text"] = response["choices"][0]["message"]["content"] - text_choices["index"] = response["choices"][0]["index"] - text_choices["logprobs"] = transformed_logprobs - text_choices["finish_reason"] = response["choices"][0][ - "finish_reason" - ] - text_completion_response["choices"] = [text_choices] - text_completion_response["usage"] = response.get("usage", None) - text_completion_response._hidden_params = HiddenParams( - **response._hidden_params - ) - return text_completion_response - except Exception as e: - custom_llm_provider = custom_llm_provider or "openai" - raise exception_type( - model=model, - custom_llm_provider=custom_llm_provider, - original_exception=e, - completion_kwargs=args, - extra_kwargs=kwargs, - ) - - -@client -def text_completion( - prompt: Union[ - str, List[Union[str, List[Union[str, List[int]]]]] - ], # Required: The prompt(s) to generate completions for. - model: Optional[ - str - ] = None, # Optional: either `model` or `engine` can be set - best_of: Optional[ - int - ] = None, # Optional: Generates best_of completions server-side. - echo: Optional[ - bool - ] = None, # Optional: Echo back the prompt in addition to the completion. - frequency_penalty: Optional[ - float - ] = None, # Optional: Penalize new tokens based on their existing frequency. - logit_bias: Optional[ - Dict[int, int] - ] = None, # Optional: Modify the likelihood of specified tokens. - logprobs: Optional[ - int - ] = None, # Optional: Include the log probabilities on the most likely tokens. - max_tokens: Optional[ - int - ] = None, # Optional: The maximum number of tokens to generate in the completion. - n: Optional[ - int - ] = None, # Optional: How many completions to generate for each prompt. - presence_penalty: Optional[ - float - ] = None, # Optional: Penalize new tokens based on whether they appear in the text so far. - stop: Optional[ - Union[str, List[str]] - ] = None, # Optional: Sequences where the API will stop generating further tokens. - stream: Optional[ - bool - ] = None, # Optional: Whether to stream back partial progress. - stream_options: Optional[dict] = None, - suffix: Optional[ - str - ] = None, # Optional: The suffix that comes after a completion of inserted text. - temperature: Optional[ - float - ] = None, # Optional: Sampling temperature to use. - top_p: Optional[float] = None, # Optional: Nucleus sampling parameter. - user: Optional[ - str - ] = None, # Optional: A unique identifier representing your end-user. - # set api_base, api_version, api_key - api_base: Optional[str] = None, - api_version: Optional[str] = None, - api_key: Optional[str] = None, - model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. - # Optional liteLLM function params - custom_llm_provider: Optional[str] = None, - *args, - **kwargs, -): - global print_verbose - import copy - - """ - Generate text completions using the OpenAI API. - - Args: - model (str): ID of the model to use. - prompt (Union[str, List[Union[str, List[Union[str, List[int]]]]]): The prompt(s) to generate completions for. - best_of (Optional[int], optional): Generates best_of completions server-side. Defaults to 1. - echo (Optional[bool], optional): Echo back the prompt in addition to the completion. Defaults to False. - frequency_penalty (Optional[float], optional): Penalize new tokens based on their existing frequency. Defaults to 0. - logit_bias (Optional[Dict[int, int]], optional): Modify the likelihood of specified tokens. Defaults to None. - logprobs (Optional[int], optional): Include the log probabilities on the most likely tokens. Defaults to None. - max_tokens (Optional[int], optional): The maximum number of tokens to generate in the completion. Defaults to 16. - n (Optional[int], optional): How many completions to generate for each prompt. Defaults to 1. - presence_penalty (Optional[float], optional): Penalize new tokens based on whether they appear in the text so far. Defaults to 0. - stop (Optional[Union[str, List[str]]], optional): Sequences where the API will stop generating further tokens. Defaults to None. - stream (Optional[bool], optional): Whether to stream back partial progress. Defaults to False. - suffix (Optional[str], optional): The suffix that comes after a completion of inserted text. Defaults to None. - temperature (Optional[float], optional): Sampling temperature to use. Defaults to 1. - top_p (Optional[float], optional): Nucleus sampling parameter. Defaults to 1. - user (Optional[str], optional): A unique identifier representing your end-user. - Returns: - TextCompletionResponse: A response object containing the generated completion and associated metadata. - - Example: - Your example of how to use this function goes here. - """ - if "engine" in kwargs: - if model == None: - # only use engine when model not passed - model = kwargs["engine"] - kwargs.pop("engine") - - text_completion_response = TextCompletionResponse() - - optional_params: Dict[str, Any] = {} - # default values for all optional params are none, litellm only passes them to the llm when they are set to non None values - if best_of is not None: - optional_params["best_of"] = best_of - if echo is not None: - optional_params["echo"] = echo - if frequency_penalty is not None: - optional_params["frequency_penalty"] = frequency_penalty - if logit_bias is not None: - optional_params["logit_bias"] = logit_bias - if logprobs is not None: - optional_params["logprobs"] = logprobs - if max_tokens is not None: - optional_params["max_tokens"] = max_tokens - if n is not None: - optional_params["n"] = n - if presence_penalty is not None: - optional_params["presence_penalty"] = presence_penalty - if stop is not None: - optional_params["stop"] = stop - if stream is not None: - optional_params["stream"] = stream - if stream_options is not None: - optional_params["stream_options"] = stream_options - if suffix is not None: - optional_params["suffix"] = suffix - if temperature is not None: - optional_params["temperature"] = temperature - if top_p is not None: - optional_params["top_p"] = top_p - if user is not None: - optional_params["user"] = user - if api_base is not None: - optional_params["api_base"] = api_base - if api_version is not None: - optional_params["api_version"] = api_version - if api_key is not None: - optional_params["api_key"] = api_key - if custom_llm_provider is not None: - optional_params["custom_llm_provider"] = custom_llm_provider - - # get custom_llm_provider - _model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore - - if custom_llm_provider == "huggingface": - # if echo == True, for TGI llms we need to set top_n_tokens to 3 - if echo == True: - # for tgi llms - if "top_n_tokens" not in kwargs: - kwargs["top_n_tokens"] = 3 - - # processing prompt - users can pass raw tokens to OpenAI Completion() - if type(prompt) == list: - import concurrent.futures - - tokenizer = tiktoken.encoding_for_model("text-davinci-003") - ## if it's a 2d list - each element in the list is a text_completion() request - if len(prompt) > 0 and type(prompt[0]) == list: - responses = [None for x in prompt] # init responses - - def process_prompt(i, individual_prompt): - decoded_prompt = tokenizer.decode(individual_prompt) - all_params = {**kwargs, **optional_params} - response = text_completion( - model=model, - prompt=decoded_prompt, - num_retries=3, # ensure this does not fail for the batch - *args, - **all_params, - ) - - text_completion_response["id"] = response.get("id", None) - text_completion_response["object"] = "text_completion" - text_completion_response["created"] = response.get( - "created", None - ) - text_completion_response["model"] = response.get( - "model", None - ) - return response["choices"][0] - - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [ - executor.submit(process_prompt, i, individual_prompt) - for i, individual_prompt in enumerate(prompt) - ] - for i, future in enumerate( - concurrent.futures.as_completed(futures) - ): - responses[i] = future.result() - text_completion_response.choices = responses # type: ignore - - return text_completion_response - # else: - # check if non default values passed in for best_of, echo, logprobs, suffix - # these are the params supported by Completion() but not ChatCompletion - - # default case, non OpenAI requests go through here - # handle prompt formatting if prompt is a string vs. list of strings - messages = [] - if ( - isinstance(prompt, list) - and len(prompt) > 0 - and isinstance(prompt[0], str) - ): - for p in prompt: - message = {"role": "user", "content": p} - messages.append(message) - elif isinstance(prompt, str): - messages = [{"role": "user", "content": prompt}] - elif ( - ( - custom_llm_provider == "openai" - or custom_llm_provider == "azure" - or custom_llm_provider == "azure_text" - or custom_llm_provider == "text-completion-codestral" - or custom_llm_provider == "text-completion-openai" - ) - and isinstance(prompt, list) - and len(prompt) > 0 - and isinstance(prompt[0], list) - ): - verbose_logger.warning( - msg="List of lists being passed. If this is for tokens, then it might not work across all models." - ) - messages = [{"role": "user", "content": prompt}] # type: ignore - else: - raise Exception( - f"Unmapped prompt format. Your prompt is neither a list of strings nor a string. prompt={prompt}. File an issue - https://github.com/BerriAI/litellm/issues" - ) - - kwargs.pop("prompt", None) - - if _model is not None and ( - custom_llm_provider == "openai" - ): # for openai compatible endpoints - e.g. vllm, call the native /v1/completions endpoint for text completion calls - if _model not in litellm.open_ai_chat_completion_models: - model = "text-completion-openai/" + _model - optional_params.pop("custom_llm_provider", None) - - kwargs["text_completion"] = True - response = completion( - model=model, - messages=messages, - *args, - **kwargs, - **optional_params, - ) - if kwargs.get("acompletion", False) is True: - return response - if stream is True or kwargs.get("stream", False) is True: - response = TextCompletionStreamWrapper( - completion_stream=response, - model=model, - stream_options=stream_options, - custom_llm_provider=custom_llm_provider, - ) - return response - transformed_logprobs = None - # only supported for TGI models - try: - raw_response = response._hidden_params.get("original_response", None) - transformed_logprobs = litellm.utils.transform_logprobs(raw_response) - except Exception as e: - print_verbose(f"LiteLLM non blocking exception: {e}") - - if isinstance(response, TextCompletionResponse): - return response - - text_completion_response["id"] = response.get("id", None) - text_completion_response["object"] = "text_completion" - text_completion_response["created"] = response.get("created", None) - text_completion_response["model"] = response.get("model", None) - text_choices = TextChoices() - text_choices["text"] = response["choices"][0]["message"]["content"] - text_choices["index"] = response["choices"][0]["index"] - text_choices["logprobs"] = transformed_logprobs - text_choices["finish_reason"] = response["choices"][0]["finish_reason"] - text_completion_response["choices"] = [text_choices] - text_completion_response["usage"] = response.get("usage", None) - text_completion_response._hidden_params = HiddenParams( - **response._hidden_params - ) - - return text_completion_response - - -###### Adapter Completion ################ - - -async def aadapter_completion( - *, adapter_id: str, **kwargs -) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]: - """ - Implemented to handle async calls for adapter_completion() - """ - try: - translation_obj: Optional[CustomLogger] = None - for item in litellm.adapters: - if item["id"] == adapter_id: - translation_obj = item["adapter"] - - if translation_obj is None: - raise ValueError( - "No matching adapter given. Received 'adapter_id'={}, litellm.adapters={}".format( - adapter_id, litellm.adapters - ) - ) - - new_kwargs = translation_obj.translate_completion_input_params( - kwargs=kwargs - ) - - response: Union[ModelResponse, CustomStreamWrapper] = await acompletion(**new_kwargs) # type: ignore - translated_response: Optional[ - Union[BaseModel, AdapterCompletionStreamWrapper] - ] = None - if isinstance(response, ModelResponse): - translated_response = ( - translation_obj.translate_completion_output_params( - response=response - ) - ) - if isinstance(response, CustomStreamWrapper): - translated_response = ( - translation_obj.translate_completion_output_params_streaming( - completion_stream=response - ) - ) - - return translated_response - except Exception as e: - raise e - - -def adapter_completion( - *, adapter_id: str, **kwargs -) -> Optional[Union[BaseModel, AdapterCompletionStreamWrapper]]: - translation_obj: Optional[CustomLogger] = None - for item in litellm.adapters: - if item["id"] == adapter_id: - translation_obj = item["adapter"] - - if translation_obj is None: - raise ValueError( - "No matching adapter given. Received 'adapter_id'={}, litellm.adapters={}".format( - adapter_id, litellm.adapters - ) - ) - - new_kwargs = translation_obj.translate_completion_input_params( - kwargs=kwargs - ) - - response: Union[ModelResponse, CustomStreamWrapper] = completion(**new_kwargs) # type: ignore - translated_response: Optional[ - Union[BaseModel, AdapterCompletionStreamWrapper] - ] = None - if isinstance(response, ModelResponse): - translated_response = ( - translation_obj.translate_completion_output_params( - response=response - ) - ) - elif isinstance(response, CustomStreamWrapper) or inspect.isgenerator( - response - ): - translated_response = ( - translation_obj.translate_completion_output_params_streaming( - completion_stream=response - ) - ) - - return translated_response - - -##### Moderation ####################### - - -def moderation( - input: str, - model: Optional[str] = None, - api_key: Optional[str] = None, - **kwargs, -): - # only supports open ai for now - api_key = ( - api_key - or litellm.api_key - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) - - openai_client = kwargs.get("client", None) - if openai_client is None: - openai_client = openai.OpenAI( - api_key=api_key, - ) - - response = openai_client.moderations.create(input=input, model=model) - return response - - -@client -async def amoderation( - input: str, - model: Optional[str] = None, - api_key: Optional[str] = None, - **kwargs, -): - # only supports open ai for now - api_key = ( - api_key - or litellm.api_key - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) - openai_client = kwargs.get("client", None) - if openai_client is None: - # call helper to get OpenAI client - # _get_openai_client maintains in-memory caching logic for OpenAI clients - openai_client = openai_chat_completions._get_openai_client( - is_async=True, - api_key=api_key, - ) - response = await openai_client.moderations.create(input=input, model=model) - return response - - -##### Image Generation ####################### -@client -async def aimage_generation(*args, **kwargs) -> ImageResponse: - """ - Asynchronously calls the `image_generation` function with the given arguments and keyword arguments. - - Parameters: - - `args` (tuple): Positional arguments to be passed to the `image_generation` function. - - `kwargs` (dict): Keyword arguments to be passed to the `image_generation` function. - - Returns: - - `response` (Any): The response returned by the `image_generation` function. - """ - loop = asyncio.get_event_loop() - model = args[0] if len(args) > 0 else kwargs["model"] - ### PASS ARGS TO Image Generation ### - kwargs["aimg_generation"] = True - custom_llm_provider = None - try: - # Use a partial function to pass your keyword arguments - func = partial(image_generation, *args, **kwargs) - - # Add the context to the function - ctx = contextvars.copy_context() - func_with_context = partial(ctx.run, func) - - _, custom_llm_provider, _, _ = get_llm_provider( - model=model, api_base=kwargs.get("api_base", None) - ) - - # Await normally - init_response = await loop.run_in_executor(None, func_with_context) - if isinstance(init_response, dict) or isinstance( - init_response, ImageResponse - ): ## CACHING SCENARIO - if isinstance(init_response, dict): - init_response = ImageResponse(**init_response) - response = init_response - elif asyncio.iscoroutine(init_response): - response = await init_response - else: - # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) - return response - except Exception as e: - custom_llm_provider = custom_llm_provider or "openai" - raise exception_type( - model=model, - custom_llm_provider=custom_llm_provider, - original_exception=e, - completion_kwargs=args, - extra_kwargs=kwargs, - ) - - -@client -def image_generation( - prompt: str, - model: Optional[str] = None, - n: Optional[int] = None, - quality: Optional[str] = None, - response_format: Optional[str] = None, - size: Optional[str] = None, - style: Optional[str] = None, - user: Optional[str] = None, - timeout=600, # default to 10 minutes - api_key: Optional[str] = None, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - litellm_logging_obj=None, - custom_llm_provider=None, - **kwargs, -) -> ImageResponse: - """ - Maps the https://api.openai.com/v1/images/generations endpoint. - - Currently supports just Azure + OpenAI. - """ - try: - aimg_generation = kwargs.get("aimg_generation", False) - litellm_call_id = kwargs.get("litellm_call_id", None) - logger_fn = kwargs.get("logger_fn", None) - proxy_server_request = kwargs.get("proxy_server_request", None) - model_info = kwargs.get("model_info", None) - metadata = kwargs.get("metadata", {}) - client = kwargs.get("client", None) - - model_response = litellm.utils.ImageResponse() - if model is not None or custom_llm_provider is not None: - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore - else: - model = "dall-e-2" - custom_llm_provider = "openai" # default to dall-e-2 on openai - model_response._hidden_params["model"] = model - openai_params = [ - "user", - "request_timeout", - "api_base", - "api_version", - "api_key", - "deployment_id", - "organization", - "base_url", - "default_headers", - "timeout", - "max_retries", - "n", - "quality", - "size", - "style", - ] - litellm_params = [ - "metadata", - "aimg_generation", - "caching", - "mock_response", - "api_key", - "api_version", - "api_base", - "force_timeout", - "logger_fn", - "verbose", - "custom_llm_provider", - "litellm_logging_obj", - "litellm_call_id", - "use_client", - "id", - "fallbacks", - "azure", - "headers", - "model_list", - "num_retries", - "context_window_fallback_dict", - "retry_policy", - "roles", - "final_prompt_value", - "bos_token", - "eos_token", - "request_timeout", - "complete_response", - "self", - "client", - "rpm", - "tpm", - "max_parallel_requests", - "input_cost_per_token", - "output_cost_per_token", - "hf_model_name", - "proxy_server_request", - "model_info", - "preset_cache_key", - "caching_groups", - "ttl", - "cache", - "region_name", - "allowed_model_region", - "model_config", - ] - default_params = openai_params + litellm_params - non_default_params = { - k: v for k, v in kwargs.items() if k not in default_params - } # model-specific params - pass them straight to the model/provider - optional_params = get_optional_params_image_gen( - n=n, - quality=quality, - response_format=response_format, - size=size, - style=style, - user=user, - custom_llm_provider=custom_llm_provider, - **non_default_params, - ) - logging: Logging = litellm_logging_obj - logging.update_environment_variables( - model=model, - user=user, - optional_params=optional_params, - litellm_params={ - "timeout": timeout, - "azure": False, - "litellm_call_id": litellm_call_id, - "logger_fn": logger_fn, - "proxy_server_request": proxy_server_request, - "model_info": model_info, - "metadata": metadata, - "preset_cache_key": None, - "stream_response": {}, - }, - custom_llm_provider=custom_llm_provider, - ) - - if custom_llm_provider == "azure": - # azure configs - api_type = get_secret("AZURE_API_TYPE") or "azure" - - api_base = ( - api_base or litellm.api_base or get_secret("AZURE_API_BASE") - ) - - api_version = ( - api_version - or litellm.api_version - or get_secret("AZURE_API_VERSION") - ) - - api_key = ( - api_key - or litellm.api_key - or litellm.azure_key - or get_secret("AZURE_OPENAI_API_KEY") - or get_secret("AZURE_API_KEY") - ) - - azure_ad_token = optional_params.pop( - "azure_ad_token", None - ) or get_secret("AZURE_AD_TOKEN") - - model_response = azure_chat_completions.image_generation( - model=model, - prompt=prompt, - timeout=timeout, - api_key=api_key, - api_base=api_base, - logging_obj=litellm_logging_obj, - optional_params=optional_params, - model_response=model_response, - api_version=api_version, - aimg_generation=aimg_generation, - client=client, - ) - elif custom_llm_provider == "openai": - model_response = openai_chat_completions.image_generation( - model=model, - prompt=prompt, - timeout=timeout, - api_key=api_key, - api_base=api_base, - logging_obj=litellm_logging_obj, - optional_params=optional_params, - model_response=model_response, - aimg_generation=aimg_generation, - client=client, - ) - elif custom_llm_provider == "bedrock": - if model is None: - raise Exception("Model needs to be set for bedrock") - model_response = bedrock_image_generation.image_generation( - model=model, - prompt=prompt, - timeout=timeout, - logging_obj=litellm_logging_obj, - optional_params=optional_params, - model_response=model_response, - aimg_generation=aimg_generation, - ) - elif custom_llm_provider == "vertex_ai": - vertex_ai_project = ( - optional_params.pop("vertex_project", None) - or optional_params.pop("vertex_ai_project", None) - or litellm.vertex_project - or get_secret("VERTEXAI_PROJECT") - ) - vertex_ai_location = ( - optional_params.pop("vertex_location", None) - or optional_params.pop("vertex_ai_location", None) - or litellm.vertex_location - or get_secret("VERTEXAI_LOCATION") - ) - vertex_credentials = ( - optional_params.pop("vertex_credentials", None) - or optional_params.pop("vertex_ai_credentials", None) - or get_secret("VERTEXAI_CREDENTIALS") - ) - model_response = vertex_image_generation.image_generation( - model=model, - prompt=prompt, - timeout=timeout, - logging_obj=litellm_logging_obj, - optional_params=optional_params, - model_response=model_response, - vertex_project=vertex_ai_project, - vertex_location=vertex_ai_location, - vertex_credentials=vertex_credentials, - aimg_generation=aimg_generation, - ) - - return model_response - except Exception as e: - ## Map to OpenAI Exception - raise exception_type( - model=model, - custom_llm_provider=custom_llm_provider, - original_exception=e, - completion_kwargs=locals(), - extra_kwargs=kwargs, - ) - - -##### Transcription ####################### - - -@client -async def atranscription(*args, **kwargs) -> TranscriptionResponse: - """ - Calls openai + azure whisper endpoints. - - Allows router to load balance between them - """ - loop = asyncio.get_event_loop() - model = args[0] if len(args) > 0 else kwargs["model"] - ### PASS ARGS TO Image Generation ### - kwargs["atranscription"] = True - custom_llm_provider = None - try: - # Use a partial function to pass your keyword arguments - func = partial(transcription, *args, **kwargs) - - # Add the context to the function - ctx = contextvars.copy_context() - func_with_context = partial(ctx.run, func) - - _, custom_llm_provider, _, _ = get_llm_provider( - model=model, api_base=kwargs.get("api_base", None) - ) - - # Await normally - init_response = await loop.run_in_executor(None, func_with_context) - if isinstance(init_response, dict): - response = TranscriptionResponse(**init_response) - elif isinstance( - init_response, TranscriptionResponse - ): ## CACHING SCENARIO - response = init_response - elif asyncio.iscoroutine(init_response): - response = await init_response - else: - # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) - return response - except Exception as e: - custom_llm_provider = custom_llm_provider or "openai" - raise exception_type( - model=model, - custom_llm_provider=custom_llm_provider, - original_exception=e, - completion_kwargs=args, - extra_kwargs=kwargs, - ) - - -@client -def transcription( - model: str, - file: FileTypes, - ## OPTIONAL OPENAI PARAMS ## - language: Optional[str] = None, - prompt: Optional[str] = None, - response_format: Optional[ - Literal["json", "text", "srt", "verbose_json", "vtt"] - ] = None, - temperature: Optional[int] = None, # openai defaults this to 0 - ## LITELLM PARAMS ## - user: Optional[str] = None, - timeout=600, # default to 10 minutes - api_key: Optional[str] = None, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - max_retries: Optional[int] = None, - litellm_logging_obj: Optional[LiteLLMLoggingObj] = None, - custom_llm_provider=None, - **kwargs, -) -> TranscriptionResponse: - """ - Calls openai + azure whisper endpoints. - - Allows router to load balance between them - """ - atranscription = kwargs.get("atranscription", False) - litellm_call_id = kwargs.get("litellm_call_id", None) - logger_fn = kwargs.get("logger_fn", None) - proxy_server_request = kwargs.get("proxy_server_request", None) - model_info = kwargs.get("model_info", None) - metadata = kwargs.get("metadata", {}) - tags = kwargs.pop("tags", []) - - drop_params = kwargs.get("drop_params", None) - client: Optional[ - Union[ - openai.AsyncOpenAI, - openai.OpenAI, - openai.AzureOpenAI, - openai.AsyncAzureOpenAI, - ] - ] = kwargs.pop("client", None) - - if litellm_logging_obj: - litellm_logging_obj.model_call_details["client"] = str(client) - - if max_retries is None: - max_retries = openai.DEFAULT_MAX_RETRIES - - model_response = litellm.utils.TranscriptionResponse() - - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore - - if dynamic_api_key is not None: - api_key = dynamic_api_key - - optional_params = get_optional_params_transcription( - model=model, - language=language, - prompt=prompt, - response_format=response_format, - temperature=temperature, - custom_llm_provider=custom_llm_provider, - drop_params=drop_params, - ) - # optional_params = { - # "language": language, - # "prompt": prompt, - # "response_format": response_format, - # "temperature": None, # openai defaults this to 0 - # } - - if custom_llm_provider == "azure": - # azure configs - api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") - - api_version = ( - api_version - or litellm.api_version - or get_secret("AZURE_API_VERSION") - ) - - azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret( - "AZURE_AD_TOKEN" - ) - - api_key = ( - api_key - or litellm.api_key - or litellm.azure_key - or get_secret("AZURE_API_KEY") - ) # type: ignore - - response = azure_audio_transcriptions.audio_transcriptions( - model=model, - audio_file=file, - optional_params=optional_params, - model_response=model_response, - atranscription=atranscription, - client=client, - timeout=timeout, - logging_obj=litellm_logging_obj, - api_base=api_base, - api_key=api_key, - api_version=api_version, - azure_ad_token=azure_ad_token, - max_retries=max_retries, - ) - elif custom_llm_provider == "openai" or custom_llm_provider == "groq": - api_base = ( - api_base - or litellm.api_base - or get_secret("OPENAI_API_BASE") - or "https://api.openai.com/v1" - ) # type: ignore - openai.organization = ( - litellm.organization - or get_secret("OPENAI_ORGANIZATION") - or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 - ) - # set API KEY - api_key = ( - api_key - or litellm.api_key - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) # type: ignore - response = openai_audio_transcriptions.audio_transcriptions( - model=model, - audio_file=file, - optional_params=optional_params, - model_response=model_response, - atranscription=atranscription, - client=client, - timeout=timeout, - logging_obj=litellm_logging_obj, - max_retries=max_retries, - api_base=api_base, - api_key=api_key, - ) - return response - - -@client -async def aspeech(*args, **kwargs) -> HttpxBinaryResponseContent: - """ - Calls openai tts endpoints. - """ - loop = asyncio.get_event_loop() - model = args[0] if len(args) > 0 else kwargs["model"] - ### PASS ARGS TO Image Generation ### - kwargs["aspeech"] = True - custom_llm_provider = kwargs.get("custom_llm_provider", None) - try: - # Use a partial function to pass your keyword arguments - func = partial(speech, *args, **kwargs) - - # Add the context to the function - ctx = contextvars.copy_context() - func_with_context = partial(ctx.run, func) - - _, custom_llm_provider, _, _ = get_llm_provider( - model=model, api_base=kwargs.get("api_base", None) - ) - - # Await normally - init_response = await loop.run_in_executor(None, func_with_context) - if asyncio.iscoroutine(init_response): - response = await init_response - else: - # Call the synchronous function using run_in_executor - response = await loop.run_in_executor(None, func_with_context) - return response # type: ignore - except Exception as e: - custom_llm_provider = custom_llm_provider or "openai" - raise exception_type( - model=model, - custom_llm_provider=custom_llm_provider, - original_exception=e, - completion_kwargs=args, - extra_kwargs=kwargs, - ) - - -@client -def speech( - model: str, - input: str, - voice: Optional[Union[str, dict]] = None, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - organization: Optional[str] = None, - project: Optional[str] = None, - max_retries: Optional[int] = None, - metadata: Optional[dict] = None, - timeout: Optional[Union[float, httpx.Timeout]] = None, - response_format: Optional[str] = None, - speed: Optional[int] = None, - client=None, - headers: Optional[dict] = None, - custom_llm_provider: Optional[str] = None, - aspeech: Optional[bool] = None, - **kwargs, -) -> HttpxBinaryResponseContent: - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore - tags = kwargs.pop("tags", []) - - optional_params = {} - if response_format is not None: - optional_params["response_format"] = response_format - if speed is not None: - optional_params["speed"] = speed # type: ignore - - if timeout is None: - timeout = litellm.request_timeout - - if max_retries is None: - max_retries = litellm.num_retries or openai.DEFAULT_MAX_RETRIES - - logging_obj = kwargs.get("litellm_logging_obj", None) - response: Optional[HttpxBinaryResponseContent] = None - if custom_llm_provider == "openai": - if voice is None or not (isinstance(voice, str)): - raise litellm.BadRequestError( - message="'voice' is required to be passed as a string for OpenAI TTS", - model=model, - llm_provider=custom_llm_provider, - ) - api_base = ( - api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there - or litellm.api_base - or get_secret("OPENAI_API_BASE") - or "https://api.openai.com/v1" - ) # type: ignore - # set API KEY - api_key = ( - api_key - or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) # type: ignore - - organization = ( - organization - or litellm.organization - or get_secret("OPENAI_ORGANIZATION") - or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 - ) # type: ignore - - project = ( - project - or litellm.project - or get_secret("OPENAI_PROJECT") - or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 - ) # type: ignore - - headers = headers or litellm.headers - - response = openai_chat_completions.audio_speech( - model=model, - input=input, - voice=voice, - optional_params=optional_params, - api_key=api_key, - api_base=api_base, - organization=organization, - project=project, - max_retries=max_retries, - timeout=timeout, - client=client, # pass AsyncOpenAI, OpenAI client - aspeech=aspeech, - ) - elif custom_llm_provider == "azure": - # azure configs - if voice is None or not (isinstance(voice, str)): - raise litellm.BadRequestError( - message="'voice' is required to be passed as a string for Azure TTS", - model=model, - llm_provider=custom_llm_provider, - ) - api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore - - api_version = ( - api_version - or litellm.api_version - or get_secret("AZURE_API_VERSION") - ) # type: ignore - - api_key = ( - api_key - or litellm.api_key - or litellm.azure_key - or get_secret("AZURE_OPENAI_API_KEY") - or get_secret("AZURE_API_KEY") - ) # type: ignore - - azure_ad_token: Optional[str] = optional_params.get("extra_body", {}).pop( # type: ignore - "azure_ad_token", None - ) or get_secret( - "AZURE_AD_TOKEN" - ) - - headers = headers or litellm.headers - - response = azure_chat_completions.audio_speech( - model=model, - input=input, - voice=voice, - optional_params=optional_params, - api_key=api_key, - api_base=api_base, - api_version=api_version, - azure_ad_token=azure_ad_token, - organization=organization, - max_retries=max_retries, - timeout=timeout, - client=client, # pass AsyncOpenAI, OpenAI client - aspeech=aspeech, - ) - elif ( - custom_llm_provider == "vertex_ai" - or custom_llm_provider == "vertex_ai_beta" - ): - from litellm.types.router import GenericLiteLLMParams - - generic_optional_params = GenericLiteLLMParams(**kwargs) - - api_base = generic_optional_params.api_base or "" - vertex_ai_project = ( - generic_optional_params.vertex_project - or litellm.vertex_project - or get_secret("VERTEXAI_PROJECT") - ) - vertex_ai_location = ( - generic_optional_params.vertex_location - or litellm.vertex_location - or get_secret("VERTEXAI_LOCATION") - ) - vertex_credentials = ( - generic_optional_params.vertex_credentials - or get_secret("VERTEXAI_CREDENTIALS") - ) - - if voice is not None and not isinstance(voice, dict): - raise litellm.BadRequestError( - message=f"'voice' is required to be passed as a dict for Vertex AI TTS, passed in voice={voice}", - model=model, - llm_provider=custom_llm_provider, - ) - response = vertex_text_to_speech.audio_speech( - _is_async=aspeech, - vertex_credentials=vertex_credentials, - vertex_project=vertex_ai_project, - vertex_location=vertex_ai_location, - timeout=timeout, - api_base=api_base, - model=model, - input=input, - voice=voice, - optional_params=optional_params, - kwargs=kwargs, - logging_obj=logging_obj, - ) - - if response is None: - raise Exception( - "Unable to map the custom llm provider={} to a known provider={}.".format( - custom_llm_provider, provider_list - ) - ) - return response - - -##### Health Endpoints ####################### - - -async def ahealth_check( - model_params: dict, - mode: Optional[ - Literal["completion", "embedding", "image_generation", "chat", "batch"] - ] = None, - prompt: Optional[str] = None, - input: Optional[List] = None, - default_timeout: float = 6000, -): - """ - Support health checks for different providers. Return remaining rate limit, etc. - - For azure/openai -> completion.with_raw_response - For rest -> litellm.acompletion() - """ - passed_in_mode: Optional[str] = None - try: - model: Optional[str] = model_params.get("model", None) - - if model is None: - raise Exception("model not set") - - if model in litellm.model_cost and mode is None: - mode = litellm.model_cost[model].get("mode") - - model, custom_llm_provider, _, _ = get_llm_provider(model=model) - - if model in litellm.model_cost and mode is None: - mode = litellm.model_cost[model].get("mode") - - mode = mode - passed_in_mode = mode - if mode is None: - mode = "chat" # default to chat completion calls - - if custom_llm_provider == "azure": - api_key = ( - model_params.get("api_key") - or get_secret("AZURE_API_KEY") - or get_secret("AZURE_OPENAI_API_KEY") - ) - - api_base = ( - model_params.get("api_base") - or get_secret("AZURE_API_BASE") - or get_secret("AZURE_OPENAI_API_BASE") - ) - - api_version = ( - model_params.get("api_version") - or get_secret("AZURE_API_VERSION") - or get_secret("AZURE_OPENAI_API_VERSION") - ) - - timeout = ( - model_params.get("timeout") - or litellm.request_timeout - or default_timeout - ) - - response = await azure_chat_completions.ahealth_check( - model=model, - messages=model_params.get( - "messages", None - ), # Replace with your actual messages list - api_key=api_key, - api_base=api_base, - api_version=api_version, - timeout=timeout, - mode=mode, - prompt=prompt, - input=input, - ) - elif ( - custom_llm_provider == "openai" - or custom_llm_provider == "text-completion-openai" - ): - api_key = model_params.get("api_key") or get_secret( - "OPENAI_API_KEY" - ) - organization = model_params.get("organization") - - timeout = ( - model_params.get("timeout") - or litellm.request_timeout - or default_timeout - ) - - api_base = model_params.get("api_base") or get_secret( - "OPENAI_API_BASE" - ) - - if custom_llm_provider == "text-completion-openai": - mode = "completion" - - response = await openai_chat_completions.ahealth_check( - model=model, - messages=model_params.get( - "messages", None - ), # Replace with your actual messages list - api_key=api_key, - api_base=api_base, - timeout=timeout, - mode=mode, - prompt=prompt, - input=input, - organization=organization, - ) - else: - model_params["cache"] = { - "no-cache": True - } # don't used cached responses for making health check calls - if mode == "embedding": - model_params.pop("messages", None) - model_params["input"] = input - await litellm.aembedding(**model_params) - response = {} - elif mode == "image_generation": - model_params.pop("messages", None) - model_params["prompt"] = prompt - await litellm.aimage_generation(**model_params) - response = {} - elif "*" in model: - from litellm.litellm_core_utils.llm_request_utils import ( - pick_cheapest_model_from_llm_provider, - ) - - # this is a wildcard model, we need to pick a random model from the provider - cheapest_model = pick_cheapest_model_from_llm_provider( - custom_llm_provider=custom_llm_provider - ) - model_params["model"] = cheapest_model - await acompletion(**model_params) - response = {} # args like remaining ratelimit etc. - else: # default to completion calls - await acompletion(**model_params) - response = {} # args like remaining ratelimit etc. - return response - except Exception as e: - stack_trace = traceback.format_exc() - if isinstance(stack_trace, str): - stack_trace = stack_trace[:1000] - - if passed_in_mode is None: - return { - "error": f"error:{str(e)}. Missing `mode`. Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models \nstacktrace: {stack_trace}" - } - - error_to_return = ( - str(e) - + "\nHave you set 'mode' - https://docs.litellm.ai/docs/proxy/health#embedding-models" - + "\nstack trace: " - + stack_trace - ) - return {"error": error_to_return} - - -####### HELPER FUNCTIONS ################ -## Set verbose to true -> ```litellm.set_verbose = True``` -def print_verbose(print_statement): - try: - verbose_logger.debug(print_statement) - if litellm.set_verbose: - print(print_statement) # noqa - except: - pass - - -def config_completion(**kwargs): - if litellm.config_path != None: - config_args = read_config_args(litellm.config_path) - # overwrite any args passed in with config args - return completion(**kwargs, **config_args) - else: - raise ValueError( - "No config path set, please set a config path using `litellm.config_path = 'path/to/config.json'`" - ) - - -def stream_chunk_builder_text_completion( - chunks: list, messages: Optional[List] = None -): - id = chunks[0]["id"] - object = chunks[0]["object"] - created = chunks[0]["created"] - model = chunks[0]["model"] - system_fingerprint = chunks[0].get("system_fingerprint", None) - finish_reason = chunks[-1]["choices"][0]["finish_reason"] - logprobs = chunks[-1]["choices"][0]["logprobs"] - - response = { - "id": id, - "object": object, - "created": created, - "model": model, - "system_fingerprint": system_fingerprint, - "choices": [ - { - "text": None, - "index": 0, - "logprobs": logprobs, - "finish_reason": finish_reason, - } - ], - "usage": { - "prompt_tokens": None, - "completion_tokens": None, - "total_tokens": None, - }, - } - content_list = [] - for chunk in chunks: - choices = chunk["choices"] - for choice in choices: - if ( - choice is not None - and hasattr(choice, "text") - and choice.get("text") is not None - ): - _choice = choice.get("text") - content_list.append(_choice) - - # Combine the "content" strings into a single string || combine the 'function' strings into a single string - combined_content = "".join(content_list) - - # Update the "content" field within the response dictionary - response["choices"][0]["text"] = combined_content - - if len(combined_content) > 0: - completion_output = combined_content - else: - completion_output = "" - # # Update usage information if needed - try: - response["usage"]["prompt_tokens"] = token_counter( - model=model, messages=messages - ) - except: # don't allow this failing to block a complete streaming response from being returned - print_verbose(f"token_counter failed, assuming prompt tokens is 0") - response["usage"]["prompt_tokens"] = 0 - response["usage"]["completion_tokens"] = token_counter( - model=model, - text=combined_content, - count_response_tokens=True, # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages - ) - response["usage"]["total_tokens"] = ( - response["usage"]["prompt_tokens"] - + response["usage"]["completion_tokens"] - ) - return response - - -def stream_chunk_builder( - chunks: list, - messages: Optional[list] = None, - start_time=None, - end_time=None, -) -> Optional[Union[ModelResponse, TextCompletionResponse]]: - try: - model_response = litellm.ModelResponse() - ### BASE-CASE ### - if len(chunks) == 0: - return None - ### SORT CHUNKS BASED ON CREATED ORDER ## - print_verbose( - "Goes into checking if chunk has hiddden created at param" - ) - if chunks[0]._hidden_params.get("created_at", None): - print_verbose("Chunks have a created at hidden param") - # Sort chunks based on created_at in ascending order - chunks = sorted( - chunks, - key=lambda x: x._hidden_params.get("created_at", float("inf")), - ) - print_verbose("Chunks sorted") - - # set hidden params from chunk to model_response - if model_response is not None and hasattr( - model_response, "_hidden_params" - ): - model_response._hidden_params = chunks[0].get("_hidden_params", {}) - id = chunks[0]["id"] - object = chunks[0]["object"] - created = chunks[0]["created"] - model = chunks[0]["model"] - system_fingerprint = chunks[0].get("system_fingerprint", None) - - if isinstance( - chunks[0]["choices"][0], litellm.utils.TextChoices - ): # route to the text completion logic - return stream_chunk_builder_text_completion( - chunks=chunks, messages=messages - ) - role = chunks[0]["choices"][0]["delta"]["role"] - finish_reason = chunks[-1]["choices"][0]["finish_reason"] - - # Initialize the response dictionary - response = { - "id": id, - "object": object, - "created": created, - "model": model, - "system_fingerprint": system_fingerprint, - "choices": [ - { - "index": 0, - "message": {"role": role, "content": ""}, - "finish_reason": finish_reason, - } - ], - "usage": { - "prompt_tokens": 0, # Modify as needed - "completion_tokens": 0, # Modify as needed - "total_tokens": 0, # Modify as needed - }, - } - - # Extract the "content" strings from the nested dictionaries within "choices" - content_list = [] - combined_content = "" - combined_arguments = "" - - tool_call_chunks = [ - chunk - for chunk in chunks - if "tool_calls" in chunk["choices"][0]["delta"] - and chunk["choices"][0]["delta"]["tool_calls"] is not None - ] - - if len(tool_call_chunks) > 0: - argument_list: List = [] - delta = tool_call_chunks[0]["choices"][0]["delta"] - message = response["choices"][0]["message"] - message["tool_calls"] = [] - id = None - name = None - type = None - tool_calls_list = [] - prev_index = None - prev_name = None - prev_id = None - curr_id = None - curr_index = 0 - for chunk in tool_call_chunks: - choices = chunk["choices"] - for choice in choices: - delta = choice.get("delta", {}) - tool_calls = delta.get("tool_calls", "") - # Check if a tool call is present - if tool_calls and tool_calls[0].function is not None: - if tool_calls[0].id: - id = tool_calls[0].id - curr_id = id - if prev_id is None: - prev_id = curr_id - if tool_calls[0].index: - curr_index = tool_calls[0].index - if tool_calls[0].function.arguments: - # Now, tool_calls is expected to be a dictionary - arguments = tool_calls[0].function.arguments - argument_list.append(arguments) - if tool_calls[0].function.name: - name = tool_calls[0].function.name - if tool_calls[0].type: - type = tool_calls[0].type - if prev_index is None: - prev_index = curr_index - if prev_name is None: - prev_name = name - if curr_index != prev_index: # new tool call - combined_arguments = "".join(argument_list) - tool_calls_list.append( - { - "id": prev_id, - "function": { - "arguments": combined_arguments, - "name": prev_name, - }, - "type": type, - } - ) - argument_list = [] # reset - prev_index = curr_index - prev_id = curr_id - prev_name = name - - combined_arguments = ( - "".join(argument_list) or "{}" - ) # base case, return empty dict - - tool_calls_list.append( - { - "id": id, - "function": { - "arguments": combined_arguments, - "name": name, - }, - "type": type, - } - ) - response["choices"][0]["message"]["content"] = None - response["choices"][0]["message"]["tool_calls"] = tool_calls_list - - function_call_chunks = [ - chunk - for chunk in chunks - if "function_call" in chunk["choices"][0]["delta"] - and chunk["choices"][0]["delta"]["function_call"] is not None - ] - - if len(function_call_chunks) > 0: - argument_list = [] - delta = function_call_chunks[0]["choices"][0]["delta"] - function_call = delta.get("function_call", "") - function_call_name = function_call.name - - message = response["choices"][0]["message"] - message["function_call"] = {} - message["function_call"]["name"] = function_call_name - - for chunk in function_call_chunks: - choices = chunk["choices"] - for choice in choices: - delta = choice.get("delta", {}) - function_call = delta.get("function_call", "") - - # Check if a function call is present - if function_call: - # Now, function_call is expected to be a dictionary - arguments = function_call.arguments - argument_list.append(arguments) - - combined_arguments = "".join(argument_list) - response["choices"][0]["message"]["content"] = None - response["choices"][0]["message"]["function_call"][ - "arguments" - ] = combined_arguments - - content_chunks = [ - chunk - for chunk in chunks - if "content" in chunk["choices"][0]["delta"] - and chunk["choices"][0]["delta"]["content"] is not None - ] - - if len(content_chunks) > 0: - for chunk in chunks: - choices = chunk["choices"] - for choice in choices: - delta = choice.get("delta", {}) - content = delta.get("content", "") - if content is None: - continue # openai v1.0.0 sets content = None for chunks - content_list.append(content) - - # Combine the "content" strings into a single string || combine the 'function' strings into a single string - combined_content = "".join(content_list) - - # Update the "content" field within the response dictionary - response["choices"][0]["message"]["content"] = combined_content - - completion_output = "" - if len(combined_content) > 0: - completion_output += combined_content - if len(combined_arguments) > 0: - completion_output += combined_arguments - - # Update usage information if needed - prompt_tokens = 0 - completion_tokens = 0 - # anthropic prompt caching information - cache_creation_input_tokens: Optional[int] = None - cache_read_input_tokens: Optional[int] = None - for chunk in chunks: - usage_chunk: Optional[Usage] = None - if "usage" in chunk: - usage_chunk = chunk.usage - elif ( - hasattr(chunk, "_hidden_params") - and "usage" in chunk._hidden_params - ): - usage_chunk = chunk._hidden_params["usage"] - if usage_chunk is not None: - if "prompt_tokens" in usage_chunk: - prompt_tokens = usage_chunk.get("prompt_tokens", 0) or 0 - if "completion_tokens" in usage_chunk: - completion_tokens = ( - usage_chunk.get("completion_tokens", 0) or 0 - ) - if "cache_creation_input_tokens" in usage_chunk: - cache_creation_input_tokens = usage_chunk.get( - "cache_creation_input_tokens" - ) - if "cache_read_input_tokens" in usage_chunk: - cache_read_input_tokens = usage_chunk.get( - "cache_read_input_tokens" - ) - - try: - response["usage"][ - "prompt_tokens" - ] = prompt_tokens or token_counter(model=model, messages=messages) - except ( - Exception - ): # don't allow this failing to block a complete streaming response from being returned - print_verbose("token_counter failed, assuming prompt tokens is 0") - response["usage"]["prompt_tokens"] = 0 - response["usage"][ - "completion_tokens" - ] = completion_tokens or token_counter( - model=model, - text=completion_output, - count_response_tokens=True, # count_response_tokens is a Flag to tell token counter this is a response, No need to add extra tokens we do for input messages - ) - response["usage"]["total_tokens"] = ( - response["usage"]["prompt_tokens"] - + response["usage"]["completion_tokens"] - ) - - if cache_creation_input_tokens is not None: - response["usage"][ - "cache_creation_input_tokens" - ] = cache_creation_input_tokens - if cache_read_input_tokens is not None: - response["usage"][ - "cache_read_input_tokens" - ] = cache_read_input_tokens - - return convert_to_model_response_object( - response_object=response, - model_response_object=model_response, - start_time=start_time, - end_time=end_time, - ) # type: ignore - except Exception as e: - verbose_logger.exception( - "litellm.main.py::stream_chunk_builder() - Exception occurred - {}".format( - str(e) - ) - ) - raise litellm.APIError( - status_code=500, - message="Error building chunks for logging/streaming usage calculation", - llm_provider="", - model="", - ) diff --git a/tests/test_toolkit/test_litellm.py b/tests/test_toolkit/test_litellm.py index ace4ea34..bfdc6bd8 100644 --- a/tests/test_toolkit/test_litellm.py +++ b/tests/test_toolkit/test_litellm.py @@ -46,10 +46,12 @@ {"provider": "google", "model": "gemini-1.5-pro-latest"}, {"provider": "google", "model": "gemini-1.5-flash-latest"}, {"provider": "google", "model": "gemini-1.0-pro-latest"}, + # # {"provider": "replicate", "model": "mistral-7b-instruct-v0.2"}, removed due to replicate side error # {"provider": "replicate", "model": "mixtral-8x7b-instruct-v0.1"}, removed due to replicate side error # {"provider": "replicate", "model": "meta-llama-3-70b-instruct"}, removed due to replicate side error # {"provider": "replicate", "model": "meta-llama-3.1-405b-instruct"}, removed due to replicate side error + # {"provider": "replicate", "model": "meta-llama-3-8b-instruct"}, {"provider": "togetherai", "model": "Mistral-7B-Instruct-v0.2"}, {"provider": "togetherai", "model": "Mixtral-8x7B-Instruct-v0.1"}, From cf072797d3ced571f9c6a2c8512a9cc00dcffa14 Mon Sep 17 00:00:00 2001 From: Tze-Yang Tung Date: Thu, 26 Sep 2024 14:58:19 -0400 Subject: [PATCH 5/7] minor cleanup --- notdiamond/toolkit/litellm/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/notdiamond/toolkit/litellm/main.py b/notdiamond/toolkit/litellm/main.py index 61f76d5b..205082cc 100644 --- a/notdiamond/toolkit/litellm/main.py +++ b/notdiamond/toolkit/litellm/main.py @@ -432,7 +432,7 @@ def get_llm_provider( or get_secret("FRIENDLI_TOKEN") ) elif custom_llm_provider == "notdiamond": - api_base = "https://not-diamond-server.onrender.com/v2/optimizer/modelSelect" + api_base = "https://api.notdiamond.ai/v2/optimizer/modelSelect" dynamic_api_key = get_secret("NOTDIAMOND_API_KEY") or None if api_base is not None and not isinstance(api_base, str): raise Exception( @@ -1341,7 +1341,7 @@ def completion( api_base or litellm.api_base or get_secret("NOTDIAMOND_API_BASE") - or "https://not-diamond-server.onrender.com/v2/optimizer/modelSelect" + or "https://api.notdiamond.ai/v2/optimizer/modelSelect" ) # since notdiamond.completion() internally calls other models' completion functions From 04834c91dbda696b26c31d0c57305dd82207b3ec Mon Sep 17 00:00:00 2001 From: Tze-Yang Tung Date: Mon, 30 Sep 2024 09:40:24 -0400 Subject: [PATCH 6/7] minor cleanup --- notdiamond/toolkit/litellm/main.py | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/notdiamond/toolkit/litellm/main.py b/notdiamond/toolkit/litellm/main.py index 205082cc..d532c85e 100644 --- a/notdiamond/toolkit/litellm/main.py +++ b/notdiamond/toolkit/litellm/main.py @@ -2,33 +2,13 @@ import asyncio import contextvars -import datetime import inspect -import json import os -import random -import sys -import threading import time -import traceback -import uuid -from concurrent.futures import ThreadPoolExecutor from copy import deepcopy from functools import partial -from typing import ( - Any, - Callable, - Dict, - List, - Literal, - Mapping, - Optional, - Tuple, - Type, - Union, -) +from typing import List, Optional, Tuple, Type, Union -import dotenv import httpx import litellm import openai @@ -152,7 +132,6 @@ token_counter, ) from pydantic import BaseModel -from typing_extensions import overload encoding = tiktoken.get_encoding("cl100k_base") from litellm.main import ( From 2c3bcf390042279a2d39f051d92f9b033b154d69 Mon Sep 17 00:00:00 2001 From: Tze-Yang Tung Date: Mon, 30 Sep 2024 11:39:07 -0400 Subject: [PATCH 7/7] more cleanup --- notdiamond/toolkit/litellm/main.py | 149 ++--------------------------- tests/test_toolkit/test_litellm.py | 2 +- 2 files changed, 8 insertions(+), 143 deletions(-) diff --git a/notdiamond/toolkit/litellm/main.py b/notdiamond/toolkit/litellm/main.py index d532c85e..fad018c6 100644 --- a/notdiamond/toolkit/litellm/main.py +++ b/notdiamond/toolkit/litellm/main.py @@ -14,18 +14,12 @@ import openai import tiktoken from litellm import ( # type: ignore - Logging, client, exception_type, get_litellm_params, get_optional_params, ) from litellm._logging import verbose_logger -from litellm.caching import disable_cache, enable_cache, update_cache -from litellm.integrations.custom_logger import CustomLogger -from litellm.litellm_core_utils.litellm_logging import ( - Logging as LiteLLMLoggingObj, -) from litellm.llms import ( aleph_alpha, baseten, @@ -43,33 +37,10 @@ vllm, ) from litellm.llms.AI21 import completion as ai21 -from litellm.llms.anthropic.chat import AnthropicChatCompletion -from litellm.llms.anthropic.completion import AnthropicTextCompletion -from litellm.llms.azure_text import AzureTextCompletion -from litellm.llms.AzureOpenAI.audio_transcriptions import ( - AzureAudioTranscription, -) -from litellm.llms.AzureOpenAI.azure import ( - AzureChatCompletion, - _check_dynamic_azure_params, -) -from litellm.llms.bedrock import ( - image_generation as bedrock_image_generation, # type: ignore -) -from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM -from litellm.llms.bedrock.embed.embedding import BedrockEmbedding +from litellm.llms.AzureOpenAI.azure import _check_dynamic_azure_params from litellm.llms.cohere import chat as cohere_chat from litellm.llms.cohere import completion as cohere_completion # type: ignore -from litellm.llms.cohere import embed as cohere_embed from litellm.llms.custom_llm import CustomLLM, custom_chat_llm_router -from litellm.llms.databricks.chat import DatabricksChatCompletion -from litellm.llms.huggingface_restapi import Huggingface -from litellm.llms.OpenAI.audio_transcriptions import OpenAIAudioTranscription -from litellm.llms.OpenAI.openai import ( - OpenAIChatCompletion, - OpenAITextCompletion, -) -from litellm.llms.predibase import PredibaseChatCompletion from litellm.llms.prompt_templates.factory import ( custom_prompt, function_call_prompt, @@ -77,133 +48,27 @@ prompt_factory, stringify_json_tool_call_content, ) -from litellm.llms.sagemaker.sagemaker import SagemakerLLM -from litellm.llms.text_completion_codestral import CodestralTextCompletion -from litellm.llms.triton import TritonChatCompletion from litellm.llms.vertex_ai_and_google_ai_studio import ( vertex_ai_anthropic, vertex_ai_non_gemini, ) -from litellm.llms.vertex_ai_and_google_ai_studio.gemini.vertex_and_google_ai_studio_gemini import ( - VertexLLM, -) -from litellm.llms.vertex_ai_and_google_ai_studio.gemini_embeddings.batch_embed_content_handler import ( - GoogleBatchEmbeddings, -) -from litellm.llms.vertex_ai_and_google_ai_studio.image_generation.image_generation_handler import ( - VertexImageGeneration, -) -from litellm.llms.vertex_ai_and_google_ai_studio.multimodal_embeddings.embedding_handler import ( - VertexMultimodalEmbedding, -) -from litellm.llms.vertex_ai_and_google_ai_studio.text_to_speech.text_to_speech_handler import ( - VertexTextToSpeechAPI, -) -from litellm.llms.vertex_ai_and_google_ai_studio.vertex_ai_partner_models.main import ( - VertexAIPartnerModels, -) -from litellm.llms.vertex_ai_and_google_ai_studio.vertex_embeddings import ( - embedding_handler as vertex_ai_embedding_handler, -) -from litellm.llms.watsonx import IBMWatsonXAI -from litellm.types.llms.openai import HttpxBinaryResponseContent -from litellm.types.utils import ( - AdapterCompletionStreamWrapper, - ChatCompletionMessageToolCall, - FileTypes, - HiddenParams, - all_litellm_params, -) -from litellm.utils import ( - CustomStreamWrapper, - Usage, - async_mock_completion_streaming_obj, - completion_with_fallbacks, - convert_to_model_response_object, - create_pretrained_tokenizer, - create_tokenizer, - get_optional_params_embeddings, - get_optional_params_image_gen, - get_optional_params_transcription, - get_secret, - mock_completion_streaming_obj, - read_config_args, - supports_httpx_timeout, - token_counter, -) -from pydantic import BaseModel - -encoding = tiktoken.get_encoding("cl100k_base") -from litellm.main import ( - AsyncCompletions, - Chat, - Completions, - LiteLLM, - _async_streaming, - aadapter_completion, - adapter_completion, - aembedding, - ahealth_check, - aimage_generation, - amoderation, - aspeech, - atext_completion, - atranscription, - batch_completion_models, - config_completion, - embedding, - image_generation, - moderation, - print_verbose, - speech, - stream_chunk_builder, - stream_chunk_builder_text_completion, - text_completion, - transcription, -) +from litellm.main import * from litellm.types.router import LiteLLM_Params +from litellm.types.utils import all_litellm_params from litellm.utils import ( - Choices, CustomStreamWrapper, - EmbeddingResponse, - ImageResponse, - Message, ModelResponse, - TextChoices, TextCompletionResponse, - TextCompletionStreamWrapper, - TranscriptionResponse, + completion_with_fallbacks, get_secret, - read_config_args, + supports_httpx_timeout, ) +from pydantic import BaseModel from . import notdiamond_key, provider_list from .litellm_notdiamond import completion as notdiamond_completion -openai_chat_completions = OpenAIChatCompletion() -openai_text_completions = OpenAITextCompletion() -openai_audio_transcriptions = OpenAIAudioTranscription() -databricks_chat_completions = DatabricksChatCompletion() -anthropic_chat_completions = AnthropicChatCompletion() -anthropic_text_completions = AnthropicTextCompletion() -azure_chat_completions = AzureChatCompletion() -azure_text_completions = AzureTextCompletion() -azure_audio_transcriptions = AzureAudioTranscription() -huggingface = Huggingface() -predibase_chat_completions = PredibaseChatCompletion() -codestral_text_completions = CodestralTextCompletion() -triton_chat_completions = TritonChatCompletion() -bedrock_chat_completion = BedrockLLM() -bedrock_converse_chat_completion = BedrockConverseLLM() -bedrock_embedding = BedrockEmbedding() -vertex_chat_completion = VertexLLM() -vertex_multimodal_embedding = VertexMultimodalEmbedding() -vertex_image_generation = VertexImageGeneration() -google_batch_embeddings = GoogleBatchEmbeddings() -vertex_partner_models_chat_completion = VertexAIPartnerModels() -vertex_text_to_speech = VertexTextToSpeechAPI() -watsonxai = IBMWatsonXAI() -sagemaker_llm = SagemakerLLM() +encoding = tiktoken.get_encoding("cl100k_base") def get_api_key(llm_provider: str, dynamic_api_key: Optional[str]): diff --git a/tests/test_toolkit/test_litellm.py b/tests/test_toolkit/test_litellm.py index bfdc6bd8..cca6fa83 100644 --- a/tests/test_toolkit/test_litellm.py +++ b/tests/test_toolkit/test_litellm.py @@ -51,8 +51,8 @@ # {"provider": "replicate", "model": "mixtral-8x7b-instruct-v0.1"}, removed due to replicate side error # {"provider": "replicate", "model": "meta-llama-3-70b-instruct"}, removed due to replicate side error # {"provider": "replicate", "model": "meta-llama-3.1-405b-instruct"}, removed due to replicate side error + # {"provider": "replicate", "model": "meta-llama-3-8b-instruct"}, # - {"provider": "replicate", "model": "meta-llama-3-8b-instruct"}, {"provider": "togetherai", "model": "Mistral-7B-Instruct-v0.2"}, {"provider": "togetherai", "model": "Mixtral-8x7B-Instruct-v0.1"}, {"provider": "togetherai", "model": "Mixtral-8x22B-Instruct-v0.1"},