diff --git a/optillm.py b/optillm.py index 498ba2ec..547c45fc 100644 --- a/optillm.py +++ b/optillm.py @@ -726,7 +726,7 @@ def proxy_models(): models_response = client.models.list() logger.debug('Models retrieved successfully') - return models_response.model_dump(), 200 + return models_response, 200 except Exception as e: logger.error(f"Error fetching models: {str(e)}") return jsonify({"error": f"Error fetching models: {str(e)}"}), 500 diff --git a/optillm/inference.py b/optillm/inference.py index b2dc4fba..f7f265c0 100644 --- a/optillm/inference.py +++ b/optillm/inference.py @@ -575,7 +575,7 @@ def _load_model(): logger.info(f"Using device: {device}") # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) # Base kwargs for model loading model_kwargs = { @@ -1545,7 +1545,7 @@ def list(self): try: import requests response = requests.get( - "https://huggingface.co/api/models?sort=downloads&direction=-1&limit=100" + "https://huggingface.co/api/models?sort=downloads&direction=-1&filter=text-generation&limit=20" ) models = response.json() model_list = [] diff --git a/optillm/litellm_wrapper.py b/optillm/litellm_wrapper.py index 5ac62322..7bd93543 100644 --- a/optillm/litellm_wrapper.py +++ b/optillm/litellm_wrapper.py @@ -1,8 +1,13 @@ import os +import time import litellm from litellm import completion +from litellm.utils import get_valid_models from typing import List, Dict, Any, Optional +# Configure litellm to drop unsupported parameters +litellm.drop_params = True + SAFETY_SETTINGS = [ {"category": cat, "threshold": "BLOCK_NONE"} for cat in [ @@ -36,16 +41,36 @@ def create(model: str, messages: List[Dict[str, str]], **kwargs): class Models: @staticmethod def list(): - # Since LiteLLM doesn't have a direct method to list models, - # we'll return a predefined list of supported models. - # This list can be expanded as needed. - return { - "data": [ - {"id": "gpt-4o-mini"}, - {"id": "gpt-4o"}, - {"id": "command-nightly"}, - # Add more models as needed - ] - } - + try: + # Get all valid models from LiteLLM + valid_models = get_valid_models() + + # Format the response to match OpenAI's API format + model_list = [] + for model in valid_models: + model_list.append({ + "id": model, + "object": "model", + "created": int(time.time()), + "owned_by": "litellm" + }) + + return { + "object": "list", + "data": model_list + } + except Exception as e: + # Fallback to a basic list if there's an error + print(f"Error fetching LiteLLM models: {str(e)}") + return { + "object": "list", + "data": [ + {"id": "gpt-4o-mini", "object": "model", "created": int(time.time()), "owned_by": "litellm"}, + {"id": "gpt-4o", "object": "model", "created": int(time.time()), "owned_by": "litellm"}, + {"id": "command-nightly", "object": "model", "created": int(time.time()), "owned_by": "litellm"}, + {"id": "claude-3-opus-20240229", "object": "model", "created": int(time.time()), "owned_by": "litellm"}, + {"id": "claude-3-sonnet-20240229", "object": "model", "created": int(time.time()), "owned_by": "litellm"}, + {"id": "gemini-1.5-pro-latest", "object": "model", "created": int(time.time()), "owned_by": "litellm"} + ] + } models = Models()