Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 66 additions & 20 deletions crates/openshell-cli/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ use openshell_core::proto::{
use openshell_core::settings::{self, SettingValueKind};
use openshell_core::{ObjectId, ObjectName};
use openshell_providers::{
ProviderRegistry, ProviderTypeProfile, detect_provider_from_command, normalize_provider_type,
parse_profile_json, parse_profile_yaml, profile_to_json, profile_to_yaml, profiles_to_json,
profiles_to_yaml,
ProviderRegistry, ProviderTypeProfile, RealDiscoveryContext, detect_provider_from_command,
discover_from_profile, normalize_provider_type, parse_profile_json, parse_profile_yaml,
profile_to_json, profile_to_yaml, profiles_to_json, profiles_to_yaml,
};
use owo_colors::OwoColorize;
use std::collections::{HashMap, HashSet};
Expand Down Expand Up @@ -1670,7 +1670,12 @@ pub async fn sandbox_create(
};
let requested_gpu = gpu || image.as_deref().is_some_and(image_requests_gpu);

let inferred_types: Vec<String> = inferred_provider_type(command).into_iter().collect();
let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?;
let inferred_types: Vec<String> = if providers_v2_enabled {
Vec::new()
} else {
inferred_provider_type(command).into_iter().collect()
};
let configured_providers = ensure_required_providers(
&mut client,
providers,
Expand Down Expand Up @@ -3592,9 +3597,8 @@ async fn auto_create_provider(
return Ok(());
}

let registry = ProviderRegistry::new();
let discovered = registry
.discover_existing(provider_type)
let discovered = discover_existing_provider_data(client, provider_type)
.await
.map_err(|err| miette::miette!("failed to discover provider '{provider_type}': {err}"))?;
let Some(discovered) = discovered else {
eprintln!(
Expand Down Expand Up @@ -4055,6 +4059,58 @@ fn service_url_for_gateway(service_url: &str, gateway_endpoint: &str) -> String
service_url.to_string()
}

async fn gateway_providers_v2_enabled(client: &mut crate::tls::GrpcClient) -> Result<bool> {
let response = client
.get_gateway_config(GetGatewayConfigRequest {})
.await
.into_diagnostic()?
.into_inner();
let Some(setting) = response.settings.get(settings::PROVIDERS_V2_ENABLED_KEY) else {
return Ok(false);
};
match setting.value.as_ref() {
Some(setting_value::Value::BoolValue(enabled)) => Ok(*enabled),
None => Ok(false),
Some(_) => Err(miette::miette!(
"gateway setting '{}' has invalid value type; expected bool",
settings::PROVIDERS_V2_ENABLED_KEY
)),
}
}

async fn fetch_provider_profile(
client: &mut crate::tls::GrpcClient,
provider_type: &str,
) -> Result<ProviderProfile> {
client
.get_provider_profile(GetProviderProfileRequest {
id: provider_type.to_string(),
})
.await
.into_diagnostic()?
.into_inner()
.profile
.ok_or_else(|| miette::miette!("provider profile '{provider_type}' missing from response"))
}

async fn discover_existing_provider_data(
client: &mut crate::tls::GrpcClient,
provider_type: &str,
) -> Result<Option<openshell_providers::DiscoveredProvider>> {
if gateway_providers_v2_enabled(client).await? {
let profile = fetch_provider_profile(client, provider_type).await?;
let profile = ProviderTypeProfile::from_proto(&profile);
discover_from_profile(&profile, &RealDiscoveryContext).map_err(|err| {
miette::miette!("failed to discover existing provider data from profile: {err}")
})
} else {
let registry = ProviderRegistry::new();
registry
.discover_existing(provider_type)
.map_err(|err| miette::miette!("failed to discover existing provider data: {err}"))
}
}

pub async fn provider_create(
server: &str,
name: &str,
Expand Down Expand Up @@ -4104,10 +4160,7 @@ pub async fn provider_create(
let mut config_map = parse_key_value_pairs(config, "--config")?;

if from_existing {
let registry = ProviderRegistry::new();
let discovered = registry
.discover_existing(&provider_type)
.map_err(|err| miette::miette!("failed to discover existing provider data: {err}"))?;
let discovered = discover_existing_provider_data(&mut client, &provider_type).await?;
let Some(discovered) = discovered else {
return Err(miette::miette!(
"no existing local credentials/config found for provider type '{provider_type}'"
Expand All @@ -4123,13 +4176,9 @@ pub async fn provider_create(
}

if credential_map.is_empty() {
let allows_refresh_bootstrap = client
.get_provider_profile(GetProviderProfileRequest {
id: provider_type.clone(),
})
let allows_refresh_bootstrap = fetch_provider_profile(&mut client, &provider_type)
.await
.ok()
.and_then(|response| response.into_inner().profile)
.is_some_and(|profile| provider_profile_allows_refresh_bootstrap(&profile));
if !allows_refresh_bootstrap {
return Err(miette::miette!(
Expand Down Expand Up @@ -4872,10 +4921,7 @@ pub async fn provider_update(
.ok_or_else(|| miette::miette!("provider '{name}' not found"))?;

let provider_type = existing.r#type;
let registry = ProviderRegistry::new();
let discovered = registry
.discover_existing(&provider_type)
.map_err(|err| miette::miette!("failed to discover existing provider data: {err}"))?;
let discovered = discover_existing_provider_data(&mut client, &provider_type).await?;
let Some(discovered) = discovered else {
return Err(miette::miette!(
"no existing local credentials/config found for provider type '{provider_type}'"
Expand Down
176 changes: 172 additions & 4 deletions crates/openshell-cli/tests/provider_commands_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ use openshell_core::proto::{
ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest,
ListSandboxesResponse, Provider, ProviderCredentialRefresh, ProviderCredentialRefreshStatus,
ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileCredential,
ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse,
ProviderProfileDiscovery, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse,
RotateProviderCredentialRequest, RotateProviderCredentialResponse, Sandbox, SandboxResponse,
SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest,
WatchSandboxRequest,
SandboxStreamEvent, ServiceStatus, SettingValue, SupervisorMessage, UpdateProviderRequest,
WatchSandboxRequest, setting_value,
};
use openshell_core::{ObjectId, ObjectName};
use std::collections::HashMap;
Expand All @@ -46,6 +46,7 @@ struct ProviderState {
refresh_requests: Arc<Mutex<Vec<ProviderRefreshRequestLog>>>,
sandbox_providers: Arc<Mutex<HashMap<String, Vec<String>>>>,
sandbox_provider_requests: Arc<Mutex<Vec<SandboxProviderRequestLog>>>,
global_settings: Arc<Mutex<HashMap<String, SettingValue>>>,
}

#[derive(Clone, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -270,7 +271,10 @@ impl OpenShell for TestOpenShell {
&self,
_request: tonic::Request<GetGatewayConfigRequest>,
) -> Result<Response<GetGatewayConfigResponse>, Status> {
Ok(Response::new(GetGatewayConfigResponse::default()))
Ok(Response::new(GetGatewayConfigResponse {
settings: self.state.global_settings.lock().await.clone(),
settings_revision: 1,
}))
}

async fn get_sandbox_provider_environment(
Expand Down Expand Up @@ -887,6 +891,15 @@ async fn run_server() -> TestServer {
}
}

async fn enable_providers_v2(ts: &TestServer) {
ts.state.global_settings.lock().await.insert(
openshell_core::settings::PROVIDERS_V2_ENABLED_KEY.to_string(),
SettingValue {
value: Some(setting_value::Value::BoolValue(true)),
},
);
}

#[tokio::test]
async fn provider_cli_run_functions_support_full_crud_flow() {
let ts = run_server().await;
Expand Down Expand Up @@ -1145,6 +1158,8 @@ credentials:
env_vars: [CUSTOM_API_KEY]
auth_style: bearer
header_name: authorization
discovery:
credentials: [api_key]
endpoints:
- host: api.custom.example
port: 443
Expand Down Expand Up @@ -1195,6 +1210,159 @@ binaries: [/usr/bin/custom]
.expect("profile delete");
}

#[tokio::test]
async fn provider_create_from_existing_uses_profile_discovery_when_v2_enabled() {
let ts = run_server().await;
enable_providers_v2(&ts).await;
ts.state.profiles.lock().await.insert(
"custom-discovery".to_string(),
ProviderProfile {
id: "custom-discovery".to_string(),
display_name: "Custom Discovery".to_string(),
credentials: vec![ProviderProfileCredential {
name: "api_key".to_string(),
env_vars: vec!["CUSTOM_DISCOVERY_API_KEY".to_string()],
required: true,
..Default::default()
}],
discovery: Some(ProviderProfileDiscovery {
credentials: vec!["api_key".to_string()],
}),
..Default::default()
},
);
let _env = EnvVarGuard::set(&[("CUSTOM_DISCOVERY_API_KEY", "profile-secret")]);

run::provider_create(
&ts.endpoint,
"custom-discovered",
"custom-discovery",
true,
&[],
&[],
&ts.tls,
)
.await
.expect("profile-backed provider create --from-existing");

let provider = ts
.state
.providers
.lock()
.await
.get("custom-discovered")
.cloned()
.expect("custom provider should be stored");
assert_eq!(provider.r#type, "custom-discovery");
assert_eq!(
provider.credentials.get("CUSTOM_DISCOVERY_API_KEY"),
Some(&"profile-secret".to_string())
);
}

#[tokio::test]
async fn provider_create_from_existing_fails_when_profile_discovery_finds_nothing() {
let ts = run_server().await;
enable_providers_v2(&ts).await;
ts.state.profiles.lock().await.insert(
"empty-discovery".to_string(),
ProviderProfile {
id: "empty-discovery".to_string(),
display_name: "Empty Discovery".to_string(),
credentials: vec![ProviderProfileCredential {
name: "api_key".to_string(),
env_vars: vec!["CUSTOM_DISCOVERY_TOKEN_NOT_SET_1460".to_string()],
required: false,
..Default::default()
}],
discovery: Some(ProviderProfileDiscovery {
credentials: vec!["api_key".to_string()],
}),
..Default::default()
},
);

let err = run::provider_create(
&ts.endpoint,
"empty-discovered",
"empty-discovery",
true,
&[],
&[],
&ts.tls,
)
.await
.expect_err("empty profile-backed discovery should fail");

assert!(
err.to_string()
.contains("no existing local credentials/config found"),
"unexpected error: {err}"
);
assert!(
!ts.state
.providers
.lock()
.await
.contains_key("empty-discovered")
);
}

#[tokio::test]
async fn provider_update_from_existing_uses_profile_discovery_when_v2_enabled() {
let ts = run_server().await;
enable_providers_v2(&ts).await;
ts.state.profiles.lock().await.insert(
"custom-update-discovery".to_string(),
ProviderProfile {
id: "custom-update-discovery".to_string(),
display_name: "Custom Update Discovery".to_string(),
credentials: vec![ProviderProfileCredential {
name: "api_key".to_string(),
env_vars: vec!["CUSTOM_UPDATE_DISCOVERY_API_KEY".to_string()],
required: true,
..Default::default()
}],
discovery: Some(ProviderProfileDiscovery {
credentials: vec!["api_key".to_string()],
}),
..Default::default()
},
);
ts.state.providers.lock().await.insert(
"custom-update".to_string(),
Provider {
metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta {
id: "id-custom-update".to_string(),
name: "custom-update".to_string(),
..Default::default()
}),
r#type: "custom-update-discovery".to_string(),
credentials: HashMap::new(),
config: HashMap::new(),
credential_expires_at_ms: HashMap::new(),
},
);
let _env = EnvVarGuard::set(&[("CUSTOM_UPDATE_DISCOVERY_API_KEY", "updated-profile-secret")]);

run::provider_update(&ts.endpoint, "custom-update", true, &[], &[], &[], &ts.tls)
.await
.expect("profile-backed provider update --from-existing");

let provider = ts
.state
.providers
.lock()
.await
.get("custom-update")
.cloned()
.expect("custom provider should still be stored");
assert_eq!(
provider.credentials.get("CUSTOM_UPDATE_DISCOVERY_API_KEY"),
Some(&"updated-profile-secret".to_string())
);
}

#[tokio::test]
async fn provider_profile_import_from_directory_imports_supported_profile_files() {
let ts = run_server().await;
Expand Down
Loading
Loading