Skip to content
Merged

fix #181

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion optillm.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ def proxy_models():
models_response = client.models.list()

logger.debug('Models retrieved successfully')
return models_response.model_dump(), 200
return models_response, 200
except Exception as e:
logger.error(f"Error fetching models: {str(e)}")
return jsonify({"error": f"Error fetching models: {str(e)}"}), 500
Expand Down
4 changes: 2 additions & 2 deletions optillm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,7 @@ def _load_model():
logger.info(f"Using device: {device}")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Base kwargs for model loading
model_kwargs = {
Expand Down Expand Up @@ -1545,7 +1545,7 @@ def list(self):
try:
import requests
response = requests.get(
"https://huggingface.co/api/models?sort=downloads&direction=-1&limit=100"
"https://huggingface.co/api/models?sort=downloads&direction=-1&filter=text-generation&limit=20"
)
models = response.json()
model_list = []
Expand Down
49 changes: 37 additions & 12 deletions optillm/litellm_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import os
import time
import litellm
from litellm import completion
from litellm.utils import get_valid_models
from typing import List, Dict, Any, Optional

# Configure litellm to drop unsupported parameters
litellm.drop_params = True

SAFETY_SETTINGS = [
{"category": cat, "threshold": "BLOCK_NONE"}
for cat in [
Expand Down Expand Up @@ -36,16 +41,36 @@ def create(model: str, messages: List[Dict[str, str]], **kwargs):
class Models:
    """Minimal stand-in for the OpenAI ``client.models`` interface.

    Exposes a single :meth:`list` static method whose return value matches
    the shape of OpenAI's ``GET /v1/models`` response.
    """

    # Static fallback model ids served when LiteLLM discovery fails, so the
    # /models endpoint always returns a usable payload.
    _FALLBACK_MODEL_IDS = (
        "gpt-4o-mini",
        "gpt-4o",
        "command-nightly",
        "claude-3-opus-20240229",
        "claude-3-sonnet-20240229",
        "gemini-1.5-pro-latest",
    )

    @staticmethod
    def list():
        """Return available models in OpenAI's list-models response format.

        Queries LiteLLM for the models valid under the configured provider
        credentials; on any failure it logs the error and falls back to
        :data:`_FALLBACK_MODEL_IDS` instead of raising.

        Returns:
            dict: ``{"object": "list", "data": [{"id", "object", "created",
            "owned_by"}, ...]}``.
        """
        # One timestamp for the whole response so entries are consistent.
        created = int(time.time())
        try:
            # Get all valid models from LiteLLM.
            model_ids = get_valid_models()
        except Exception as e:
            # Best-effort fallback: report the problem and serve the static list.
            print(f"Error fetching LiteLLM models: {str(e)}")
            model_ids = Models._FALLBACK_MODEL_IDS
        return {
            "object": "list",
            "data": [
                {
                    "id": model_id,
                    "object": "model",
                    "created": created,
                    "owned_by": "litellm",
                }
                for model_id in model_ids
            ],
        }
models = Models()