diff --git a/llmstudio/engine/__init__.py b/llmstudio/engine/__init__.py index 3f503ae3..d541f775 100644 --- a/llmstudio/engine/__init__.py +++ b/llmstudio/engine/__init__.py @@ -1,6 +1,7 @@ import json import os from pathlib import Path +from threading import Event from typing import Any, Dict, List, Optional, Union import uvicorn @@ -78,7 +79,9 @@ def _merge_configs(config1, config2): raise RuntimeError(f"Error in configuration data: {e}") -def create_engine_app(config: EngineConfig = _load_engine_config()) -> FastAPI: +def create_engine_app( + started_event: Event, config: EngineConfig = _load_engine_config() +) -> FastAPI: app = FastAPI( title=ENGINE_TITLE, description=ENGINE_DESCRIPTION, @@ -162,14 +165,15 @@ async def export(request: Request): @app.on_event("startup") async def startup_event(): + started_event.set() print(f"Running LLMstudio Engine on http://{ENGINE_HOST}:{ENGINE_PORT} ") return app -def run_engine_app(): +def run_engine_app(started_event: Event): try: - engine = create_engine_app() + engine = create_engine_app(started_event) uvicorn.run( engine, host=ENGINE_HOST, diff --git a/llmstudio/engine/providers/azure.py b/llmstudio/engine/providers/azure.py index 1bb61516..0f6b711d 100644 --- a/llmstudio/engine/providers/azure.py +++ b/llmstudio/engine/providers/azure.py @@ -120,13 +120,17 @@ async def generate_client( **function_args, **request.parameters.model_dump(), } - # Perform the asynchronous call return await asyncio.to_thread( client.chat.completions.create, **combined_args ) - except openai._exceptions.APIError as e: + except openai._exceptions.APIConnectionError as e: + raise HTTPException( + status_code=404, detail=f"There was an error reaching the endpoint: {e}" + ) + + except openai._exceptions.APIStatusError as e: raise HTTPException(status_code=e.status_code, detail=e.response.json()) def prepare_messages(self, request: AzureRequest): diff --git a/llmstudio/server.py b/llmstudio/server.py index e9643d73..c69a5b96 100644 --- 
a/llmstudio/server.py +++ b/llmstudio/server.py @@ -1,4 +1,5 @@ import threading +from threading import Event import requests @@ -29,8 +30,10 @@ def is_server_running(host, port, path="/health"): def start_server_component(host, port, run_func, server_name): if not is_server_running(host, port): - thread = threading.Thread(target=run_func, daemon=True) + started_event = Event() + thread = threading.Thread(target=run_func, daemon=True, args=(started_event,)) thread.start() + started_event.wait()  # Block until the server's startup hook sets the event, signalling it is ready. return thread else: print(f"{server_name} server already running on {host}:{port}") @@ -53,7 +56,6 @@ def setup_servers(engine, tracking, ui): TRACKING_HOST, TRACKING_PORT, run_tracking_app, "Tracking" ) - ui_thread = None if ui: ui_thread = start_server_component(UI_HOST, UI_PORT, run_ui_app, "UI") diff --git a/llmstudio/tracking/__init__.py b/llmstudio/tracking/__init__.py index 31a7fbe8..d32bc768 100644 --- a/llmstudio/tracking/__init__.py +++ b/llmstudio/tracking/__init__.py @@ -1,3 +1,5 @@ +from threading import Event + import uvicorn from fastapi import APIRouter, FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -15,7 +17,7 @@ ## Tracking -def create_tracking_app() -> FastAPI: +def create_tracking_app(started_event: Event) -> FastAPI: app = FastAPI( title=TRACKING_TITLE, description=TRACKING_DESCRIPTION, @@ -43,14 +45,15 @@ def health_check(): @app.on_event("startup") async def startup_event(): + started_event.set() print(f"Running LLMstudio Tracking on http://{TRACKING_HOST}:{TRACKING_PORT} ") return app -def run_tracking_app(): +def run_tracking_app(started_event: Event): try: - tracking = create_tracking_app() + tracking = create_tracking_app(started_event) uvicorn.run( tracking, host=TRACKING_HOST, diff --git a/llmstudio/ui/__init__.py b/llmstudio/ui/__init__.py index c2a15c22..1569aa6f 100644 --- a/llmstudio/ui/__init__.py +++ b/llmstudio/ui/__init__.py @@ -2,7 +2,7 @@ import subprocess from
pathlib import Path import threading -import webbrowser +from threading import Event from llmstudio.config import UI_PORT @@ -20,6 +20,7 @@ def run_bun_in_thread(): print(f"Error running LLMstudio UI: {e}") -def run_ui_app(): +def run_ui_app(started_event: Event): thread = threading.Thread(target=run_bun_in_thread) thread.start() + started_event.set()  # Set immediately for interface compatibility; the UI has no startup hook to await.