Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 61 additions & 2 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,10 @@ pub struct AppState {
/// Model catalog registry for enriching API responses with model metadata.
/// Loaded from embedded data at startup and optionally synced at runtime.
pub model_catalog: catalog::ModelCatalogRegistry,
/// In-memory cache of model lists fetched from static (config-file) providers.
/// Warmed on startup and refreshed periodically to avoid per-request latency.
pub static_models_cache:
Arc<tokio::sync::RwLock<std::collections::HashMap<String, providers::ModelsResponse>>>,
}

impl AppState {
Expand Down Expand Up @@ -1059,7 +1063,7 @@ impl AppState {
Arc::new(services::ProviderMetricsService::new())
};

Ok(Self {
let result = Ok(Self {
http_client,
config: Arc::new(config),
db,
Expand Down Expand Up @@ -1096,7 +1100,19 @@ impl AppState {
default_org_id,
provider_metrics,
model_catalog,
})
static_models_cache: Arc::new(tokio::sync::RwLock::new(
std::collections::HashMap::new(),
)),
});

// Warm the static models cache so /v1/models is fast from the first request
if let Ok(ref state) = result
&& state.config.features.static_models_cache.enabled()
{
state.warm_static_models_cache().await;
}

result
}

/// Ensure a default user exists for anonymous access when auth is disabled.
Expand Down Expand Up @@ -1816,6 +1832,49 @@ impl AppState {
}
}
}

/// Fetch model lists from all static (config-file) providers in parallel and
/// store them in `self.static_models_cache`. Failures for individual providers
/// are logged and skipped so one slow/broken provider cannot block the rest.
pub async fn warm_static_models_cache(&self) {
use futures::future::join_all;

let futures: Vec<_> = self
.config
.providers
.iter()
.map(|(name, cfg)| {
let name = name.to_owned();
let http = self.http_client.clone();
let cbs = self.circuit_breakers.clone();
async move {
let result = providers::list_models_for_config(cfg, &name, &http, &cbs).await;
(name, result)
}
})
.collect();

let results = join_all(futures).await;

let mut cache = self.static_models_cache.write().await;
cache.retain(|name, _| self.config.providers.get(name).is_some());
for (name, result) in results {
match result {
Ok(response) => {
cache.insert(name, response);
}
Err(e) => {
tracing::warn!(provider = %name, error = %e, "Failed to fetch models for cache warm");
}
}
}
let total_models: usize = cache.values().map(|r| r.data.len()).sum();
tracing::info!(
providers = cache.len(),
models = total_models,
"Static models cache warmed"
);
}
}

#[cfg(feature = "server")]
Expand Down
15 changes: 15 additions & 0 deletions src/cli/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,21 @@ pub(crate) async fn run_server(explicit_config_path: Option<&str>, no_browser: b
None
};

// Refresh the static models cache periodically in the background.
// The initial warm-up already happened in AppState::new, so the interval's
// immediate first tick is consumed without triggering a refresh.
if config.features.static_models_cache.enabled() {
    let interval = config.features.static_models_cache.refresh_interval();
    let state_ref = state.clone();
    tokio::spawn(async move {
        let mut ticker = tokio::time::interval(interval);
        // The default `Burst` behavior fires missed ticks back-to-back. If a
        // refresh takes longer than `interval`, that would start the next
        // refresh immediately and hit the upstream providers without pause.
        // `Delay` schedules the next tick a full `interval` after the late one.
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        ticker.tick().await; // skip the immediate first tick (already warmed)
        loop {
            ticker.tick().await;
            state_ref.warm_static_models_cache().await;
        }
    });
}

let task_tracker = state.task_tracker.clone();
let app = build_app(&config, state);

Expand Down
50 changes: 50 additions & 0 deletions src/config/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ pub struct FeaturesConfig {
/// Validates URLs with SSRF protection and enforces size limits.
#[serde(default)]
pub web_fetch: Option<WebFetchConfig>,

/// Static models cache configuration.
/// Caches model lists from config-file providers to avoid per-request latency.
#[serde(default)]
pub static_models_cache: StaticModelsCacheConfig,
}

impl FeaturesConfig {
Expand Down Expand Up @@ -2563,6 +2568,51 @@ fn default_catalog_api_url() -> String {
"https://models.dev/api.json".to_string()
}

/// Configuration for the static models cache.
///
/// Model lists from config-file providers are cached in memory and refreshed
/// periodically so that `/v1/models` does not make upstream HTTP calls on every
/// request.
///
/// Unknown keys in this table are rejected at deserialization time
/// (`deny_unknown_fields`). Omitting `refresh_interval_secs` falls back to
/// the value returned by `default_static_models_refresh_interval_secs`.
///
/// ```toml
/// [features.static_models_cache]
/// refresh_interval_secs = 300
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "json-schema", derive(schemars::JsonSchema))]
#[serde(deny_unknown_fields)]
pub struct StaticModelsCacheConfig {
/// How often to refresh the cached model lists, in seconds.
/// Set to 0 to disable caching: `enabled()` then returns `false`.
/// Default: 300 (5 minutes).
#[serde(default = "default_static_models_refresh_interval_secs")]
pub refresh_interval_secs: u64,
}

impl Default for StaticModelsCacheConfig {
    /// Builds the configuration used when the whole table is absent.
    ///
    /// The interval comes from `default_static_models_refresh_interval_secs`,
    /// the same function serde calls when only the field is missing, so both
    /// paths agree on the default.
    fn default() -> Self {
        let refresh_interval_secs = default_static_models_refresh_interval_secs();
        Self {
            refresh_interval_secs,
        }
    }
}

impl StaticModelsCacheConfig {
    /// Reports whether the cache is switched on.
    ///
    /// A refresh interval of zero seconds means "off"; any other value means "on".
    pub fn enabled(&self) -> bool {
        self.refresh_interval_secs != 0
    }

    /// Returns `refresh_interval_secs` converted to a `Duration`.
    pub fn refresh_interval(&self) -> std::time::Duration {
        use std::time::Duration;

        Duration::from_secs(self.refresh_interval_secs)
    }
}

/// Default value for `StaticModelsCacheConfig::refresh_interval_secs`:
/// five minutes, expressed in seconds.
fn default_static_models_refresh_interval_secs() -> u64 {
    const FIVE_MINUTES_SECS: u64 = 5 * 60;
    FIVE_MINUTES_SECS
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
6 changes: 6 additions & 0 deletions src/middleware/layers/admin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2373,6 +2373,9 @@ mod tests {
crate::services::ProviderMetricsService::with_local_metrics(|| None),
),
model_catalog: crate::catalog::ModelCatalogRegistry::new(),
static_models_cache: std::sync::Arc::new(tokio::sync::RwLock::new(
std::collections::HashMap::new(),
)),
}
}

Expand Down Expand Up @@ -2674,6 +2677,9 @@ mod tests {
crate::services::ProviderMetricsService::with_local_metrics(|| None),
),
model_catalog: crate::catalog::ModelCatalogRegistry::new(),
static_models_cache: std::sync::Arc::new(tokio::sync::RwLock::new(
std::collections::HashMap::new(),
)),
}
}

Expand Down
6 changes: 6 additions & 0 deletions src/middleware/layers/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2289,6 +2289,9 @@ mod tests {
crate::services::ProviderMetricsService::with_local_metrics(|| None),
),
model_catalog: crate::catalog::ModelCatalogRegistry::new(),
static_models_cache: std::sync::Arc::new(tokio::sync::RwLock::new(
std::collections::HashMap::new(),
)),
}
}

Expand Down Expand Up @@ -2340,6 +2343,9 @@ mod tests {
crate::services::ProviderMetricsService::with_local_metrics(|| None),
),
model_catalog: crate::catalog::ModelCatalogRegistry::new(),
static_models_cache: std::sync::Arc::new(tokio::sync::RwLock::new(
std::collections::HashMap::new(),
)),
}
}

Expand Down
Loading
Loading