Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 128 additions & 19 deletions flocks/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pydantic import BaseModel, Field, PrivateAttr
from enum import Enum
import os
import threading

from flocks.utils.log import Log
from flocks.config.config import Config
Expand All @@ -16,6 +17,60 @@
log = Log.create(service="provider")


def _model_info_signature(model: "ModelInfo") -> tuple:
    """Return a comparable signature of the user-visible fields of a ``ModelInfo``.

    Used by :meth:`Provider.apply_config` (through ``_model_lists_equal``) to
    short-circuit the ``_config_models`` rebuild when the desired list matches
    the existing one. Only the fields that ``apply_config`` populates from
    ``flocks.json`` are compared: id / name / provider_id / capabilities /
    pricing / explicit_keys. Other private attributes (e.g. catalog metadata)
    are not part of the user contract and are ignored.

    Args:
        model: The model description to summarise. Every field is read with
            ``getattr(..., None)``, so a missing attribute contributes ``None``
            instead of raising.

    Returns:
        A 6-tuple ``(id, name, provider_id, cap_sig, pricing_sig, explicit)``:

        * ``cap_sig`` is ``None`` when the model has no capabilities object,
          otherwise a 7-tuple of the capability flags and limits.
        * ``pricing_sig`` is ``None`` when the model has no pricing, otherwise
          an ``(input, output, currency)`` tuple.
        * ``explicit`` is the sorted tuple of ``_explicit_keys`` (empty when
          the attribute is missing or falsy).
    """
    capability_fields = (
        "supports_streaming",
        "supports_tools",
        "supports_vision",
        "supports_reasoning",
        "interleaved",
        "max_tokens",
        "context_window",
    )
    pricing_fields = ("input", "output", "currency")

    cap = getattr(model, "capabilities", None)
    if cap is None:
        cap_sig = None
    else:
        cap_sig = tuple(getattr(cap, field, None) for field in capability_fields)

    pricing = getattr(model, "pricing", None)
    if pricing is None:
        pricing_sig = None
    elif isinstance(pricing, dict):
        pricing_sig = tuple(pricing.get(field) for field in pricing_fields)
    else:
        # Pricing may also be carried by an object (e.g. a model instance)
        # rather than a plain dict. Read the same three fields as attributes
        # so that a pricing change is still reflected in the signature
        # instead of collapsing every non-dict pricing to (None, None, None).
        pricing_sig = tuple(
            getattr(pricing, field, None) for field in pricing_fields
        )

    # Sorting makes the signature independent of set iteration order.
    explicit = tuple(sorted(getattr(model, "_explicit_keys", None) or ()))

    return (
        getattr(model, "id", None),
        getattr(model, "name", None),
        getattr(model, "provider_id", None),
        cap_sig,
        pricing_sig,
        explicit,
    )


def _model_lists_equal(a: List["ModelInfo"], b: List["ModelInfo"]) -> bool:
    """Tell whether two ``ModelInfo`` lists match position by position.

    Two lists are equal when they have the same length and every pair of
    models at the same index yields the same ``_model_info_signature``.
    Order matters: the same models in a different order compare unequal.
    """
    if len(a) != len(b):
        return False
    for left, right in zip(a, b):
        # Bail out on the first mismatching pair.
        if _model_info_signature(left) != _model_info_signature(right):
            return False
    return True


class ProviderType(str, Enum):
"""Provider types"""
ANTHROPIC = "anthropic"
Expand Down Expand Up @@ -152,7 +207,14 @@ class Provider:
_providers: Dict[str, "BaseProvider"] = {}
_models: Dict[str, ModelInfo] = {}
_initialized = False

# Guards the lazy provider-registration sequence in ``_ensure_initialized``.
# ``_initialized`` must only flip to ``True`` *after* the registry is
# fully populated, otherwise a thread that wins the cheap fast-path read
# (``if not cls._initialized``) can return while another thread is still
# mid-registration, causing ``Provider.get(...)`` to return ``None`` for
# built-in providers that should already exist.
_init_lock: "threading.Lock" = threading.Lock()

@classmethod
async def init(cls) -> None:
"""Initialize provider system"""
Expand Down Expand Up @@ -187,10 +249,20 @@ def register(cls, provider: "BaseProvider") -> None:

@classmethod
def _ensure_initialized(cls):
"""Ensure providers are initialized"""
if not cls._initialized:
cls._initialized = True

"""Ensure providers are initialized (thread-safe lazy init).

Uses double-checked locking so the hot path (already initialized)
is a single class-attribute read with no lock contention, while
concurrent first-time callers see a fully populated registry —
not a half-built one — before returning.
"""
if cls._initialized:
return

with cls._init_lock:
if cls._initialized:
return

# Auto-register built-in providers (Batch 1+2)
providers_to_register = [
("openai", "flocks.provider.sdk.openai", "OpenAIProvider"),
Expand Down Expand Up @@ -246,9 +318,14 @@ def _ensure_initialized(cls):
log.debug("provider.auto_registered", {"provider": provider_id})
except Exception as e:
log.warning("provider.register.failed", {"provider": provider_id, "error": str(e)})

# Load dynamic providers from flocks.json
cls._load_dynamic_providers()

# Flip the "initialized" flag only after the registry is fully
# populated. Threads that reach the fast-path check above afterwards
# then see a complete registry, not a half-built one.
cls._initialized = True

@classmethod
def _load_dynamic_providers(cls):
Expand Down Expand Up @@ -622,25 +699,50 @@ async def apply_config(cls, config: Optional[Any] = None, provider_id: Optional[
if api_key is None and base_url is None and not options_data:
continue

provider.configure(ProviderConfig(
# ----- Idempotent ProviderConfig update -------------------------------
# ``apply_config`` is called from many hot paths: every session
# step (``session.runner._step``), every workflow ``llm.ask``,
# the ``/session/*`` HTTP routes, plus startup. When session and
# workflow run concurrently on different event loops they would
# otherwise rewrite the same ``provider._config`` repeatedly and
# race on the ``_config_models`` rebuild. Skip mutation whenever
# the desired config already matches.
desired_cfg = ProviderConfig(
provider_id=pid,
api_key=api_key,
base_url=base_url,
custom_settings=options_data,
))
)
current_cfg = provider._config
current_unchanged = (
current_cfg is not None
and getattr(current_cfg, "api_key", None) == desired_cfg.api_key
and getattr(current_cfg, "base_url", None) == desired_cfg.base_url
and (getattr(current_cfg, "custom_settings", None) or {})
== (desired_cfg.custom_settings or {})
)
if not current_unchanged:
provider.configure(desired_cfg)

# Update provider display name from flocks.json only for providers
# that support custom naming (openai-compatible instances and custom-* providers).
# Standard catalog providers (anthropic, openai, etc.) always keep their SDK name.
if pid == "openai-compatible" or pid.startswith("custom-"):
config_name = getattr(pconfig, "name", None)
if config_name and isinstance(config_name, str):
if (
config_name
and isinstance(config_name, str)
and provider.name != config_name
):
provider.name = config_name

# Load models from config

# ----- Idempotent _config_models rebuild ------------------------------
# Build the desired model list first, then assign atomically so
# readers (e.g. a session calling ``get_models()`` on another
# thread) never observe a half-rebuilt list.
models_config = getattr(pconfig, "models", None)
if models_config:
provider._config_models = []
desired_models: List[ModelInfo] = []
if isinstance(models_config, dict):
for model_id, model_data in models_config.items():
try:
Expand Down Expand Up @@ -680,19 +782,26 @@ async def apply_config(cls, config: Optional[Any] = None, provider_id: Optional[
pricing=_pricing,
)
model_info._explicit_keys = _explicit_keys
provider._config_models.append(model_info)
desired_models.append(model_info)
except Exception as e:
log.warning("provider.config_model.parse_failed", {
"provider_id": pid,
"model_id": model_id,
"error": str(e)
})

if provider._config_models:
log.info("provider.config_models.loaded", {
"provider_id": pid,
"count": len(provider._config_models)
})

# Skip mutation when the desired list matches what the
# provider already exposes — avoids racing readers on the
# mutable ``_config_models`` attribute and silences noisy
# ``config_models.loaded`` logging on every session step.
existing_models = list(getattr(provider, "_config_models", []) or [])
if not _model_lists_equal(existing_models, desired_models):
provider._config_models = desired_models
if desired_models:
log.info("provider.config_models.loaded", {
"provider_id": pid,
"count": len(desired_models)
})

@classmethod
async def chat(
Expand Down
Loading