Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/gaia/llm/lemonade_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,10 +634,13 @@ def __init__(
self.port = port if port is not None else env_port
self.base_url = f"http://{self.host}:{self.port}/api/{LEMONADE_API_VERSION}"
elif base_url is not None:
# base_url parameter provided - normalize and use it
if not base_url.rstrip("/").endswith(f"/api/{LEMONADE_API_VERSION}"):
base_url = f"{base_url.rstrip('/')}/api/{LEMONADE_API_VERSION}"
self.base_url = base_url
# base_url parameter provided - use as-is for non-Lemonade backends
# (e.g. llama.cpp uses /v1, not /api/v1)
self.base_url = base_url.rstrip("/")
# Only append /api/v1 if it looks like a bare Lemonade URL (no path)
parsed_path = urlparse(base_url).path.rstrip("/")
if not parsed_path or parsed_path == "/":
self.base_url = f"{base_url.rstrip('/')}/api/{LEMONADE_API_VERSION}"
# Parse for backwards compatibility with code accessing self.host/self.port
parsed = urlparse(base_url)
self.host = parsed.hostname or DEFAULT_HOST
Expand Down Expand Up @@ -2331,6 +2334,17 @@ def _ensure_model_loaded(self, model: str, auto_download: bool = True) -> None:
return # Skip if auto_download disabled

try:
# Quick check: if /health doesn't have Lemonade-specific fields,
# this is a plain OpenAI-compatible server (llama.cpp, etc.)
# that always has its model loaded. Skip the Lemonade load dance.
try:
health = self.health_check()
if "all_models_loaded" not in health and "model_loaded" not in health:
self.log.debug("Non-Lemonade backend detected — skipping model load check")
return
except Exception:
pass # If health check fails, continue with normal flow

# Check current server state
status = self.get_status()
loaded_models = [m.get("id", "") for m in status.loaded_models]
Expand Down
26 changes: 26 additions & 0 deletions src/gaia/ui/routers/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,32 @@ async def system_status(db: ChatDatabase = Depends(get_db)):
if legacy_ctx is not None:
status.model_context_size = legacy_ctx

# Fallback: if /health didn't report model_loaded, check
# /v1/models for llama.cpp-style backends that list loaded models
# under a "models" key (not "data").
if not status.model_loaded:
try:
v1_resp = await client.get(f"{base_url}/models", timeout=5.0)
if v1_resp.status_code == 200:
v1_data = v1_resp.json()
# llama.cpp uses {"models": [...]} with "model" field
for m in v1_data.get("models", []):
m_name = m.get("model") or m.get("id") or ""
if m_name and "embed" not in m_name.lower():
status.model_loaded = m_name
status.model_downloaded = True
break
# Also check OpenAI-style {"data": [...]}
if not status.model_loaded:
for m in v1_data.get("data", []):
m_id = m.get("id", "")
if m_id and "embed" not in m_id.lower():
status.model_loaded = m_id
status.model_downloaded = True
break
except Exception:
pass

# Fetch model catalog for size, labels, and fallback context size
models_resp = await client.get(f"{base_url}/models")
if models_resp.status_code == 200:
Expand Down
40 changes: 34 additions & 6 deletions tests/unit/test_lemonade_model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
class TestEnsureModelLoaded:
"""Test _ensure_model_loaded helper method."""

@patch.object(LemonadeClient, "health_check")
@patch.object(LemonadeClient, "get_status")
@patch.object(LemonadeClient, "load_model")
def test_calls_load_when_model_not_loaded(self, mock_load, mock_status):
def test_calls_load_when_model_not_loaded(self, mock_load, mock_status, mock_health):
"""Verify load_model is called when model not in loaded_models list."""
# Setup
mock_health.return_value = {"status": "ok", "all_models_loaded": [], "model_loaded": "model-a"}
client = LemonadeClient(host="localhost", port=8000)
mock_status.return_value = LemonadeStatus(
url="http://localhost:8000",
Expand All @@ -30,11 +32,13 @@ def test_calls_load_when_model_not_loaded(self, mock_load, mock_status):
"model-b", auto_download=True, prompt=False, ctx_size=None
)

@patch.object(LemonadeClient, "health_check")
@patch.object(LemonadeClient, "get_status")
@patch.object(LemonadeClient, "load_model")
def test_skips_load_when_model_already_loaded(self, mock_load, mock_status):
def test_skips_load_when_model_already_loaded(self, mock_load, mock_status, mock_health):
"""Verify no load_model call when model already in loaded_models list."""
# Setup
mock_health.return_value = {"status": "ok", "all_models_loaded": [], "model_loaded": "model-a"}
client = LemonadeClient(host="localhost", port=8000)
mock_status.return_value = LemonadeStatus(
url="http://localhost:8000",
Expand Down Expand Up @@ -62,11 +66,13 @@ def test_skips_check_when_auto_download_disabled(self, mock_load, mock_status):
mock_status.assert_not_called()
mock_load.assert_not_called()

@patch.object(LemonadeClient, "health_check")
@patch.object(LemonadeClient, "get_status")
@patch.object(LemonadeClient, "load_model")
def test_handles_status_check_error_gracefully(self, mock_load, mock_status):
def test_handles_status_check_error_gracefully(self, mock_load, mock_status, mock_health):
"""Verify errors during status check are logged but don't fail."""
# Setup
mock_health.return_value = {"status": "ok", "all_models_loaded": [], "model_loaded": None}
client = LemonadeClient(host="localhost", port=8000)
mock_status.side_effect = Exception("Connection failed")

Expand Down Expand Up @@ -170,11 +176,13 @@ def test_calls_ensure_model_loaded_before_request(
class TestNoPromptBehavior:
"""Test that model downloads happen without prompting."""

@patch.object(LemonadeClient, "health_check")
@patch.object(LemonadeClient, "get_status")
@patch.object(LemonadeClient, "load_model")
def test_ensure_model_loaded_passes_prompt_false(self, mock_load, mock_status):
def test_ensure_model_loaded_passes_prompt_false(self, mock_load, mock_status, mock_health):
"""Verify _ensure_model_loaded passes prompt=False to avoid user prompts."""
# Setup
mock_health.return_value = {"status": "ok", "all_models_loaded": [], "model_loaded": None}
client = LemonadeClient(host="localhost", port=8000)
mock_status.return_value = LemonadeStatus(
url="http://localhost:8000",
Expand All @@ -195,14 +203,16 @@ def test_ensure_model_loaded_passes_prompt_false(self, mock_load, mock_status):
class TestModelLoadingIntegration:
"""Integration-style tests for model loading behavior."""

@patch.object(LemonadeClient, "health_check")
@patch.object(LemonadeClient, "get_status")
@patch.object(LemonadeClient, "load_model")
@patch("gaia.llm.lemonade_client.OpenAI")
def test_model_loaded_when_not_present(
self, mock_openai_class, mock_load, mock_status
self, mock_openai_class, mock_load, mock_status, mock_health
):
"""Integration test: model is loaded when not in loaded_models list."""
# Setup
mock_health.return_value = {"status": "ok", "all_models_loaded": [], "model_loaded": "different-model"}
client = LemonadeClient(host="localhost", port=8000)

# Mock status to show model NOT loaded
Expand Down Expand Up @@ -239,14 +249,16 @@ def test_model_loaded_when_not_present(
"new-model", auto_download=True, prompt=False, ctx_size=None
)

@patch.object(LemonadeClient, "health_check")
@patch.object(LemonadeClient, "get_status")
@patch.object(LemonadeClient, "load_model")
@patch("gaia.llm.lemonade_client.OpenAI")
def test_model_not_loaded_when_already_present(
self, mock_openai_class, mock_load, mock_status
self, mock_openai_class, mock_load, mock_status, mock_health
):
"""Integration test: no load when model already in loaded_models list."""
# Setup
mock_health.return_value = {"status": "ok", "all_models_loaded": [], "model_loaded": "existing-model"}
client = LemonadeClient(host="localhost", port=8000)

# Mock status to show model IS loaded
Expand Down Expand Up @@ -280,3 +292,19 @@ def test_model_not_loaded_when_already_present(

# Verify load_model was NOT called (model already loaded)
mock_load.assert_not_called()

class TestNonLemonadeBackend:
    """Tests for detection of non-Lemonade OpenAI-compatible backends (e.g. llama.cpp)."""

    @patch.object(LemonadeClient, "health_check")
    @patch.object(LemonadeClient, "get_status")
    @patch.object(LemonadeClient, "load_model")
    def test_skips_load_for_non_lemonade_backend(self, mock_load, mock_status, mock_health):
        """A /health payload without Lemonade-specific keys short-circuits the load path."""
        # A plain OpenAI-compatible server reports neither "all_models_loaded"
        # nor "model_loaded" in its health response.
        mock_health.return_value = {"status": "ok"}
        backend = LemonadeClient(host="localhost", port=8000)

        backend._ensure_model_loaded("model-a", auto_download=True)

        # Detection must bail out before any Lemonade-specific calls are made.
        mock_status.assert_not_called()
        mock_load.assert_not_called()
Loading