diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index cd85e31b1..5f1bf7ce5 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -299,12 +299,17 @@ pub(super) async fn delete_provider_record(store: &Store, name: &str) -> Result< .map_err(|e| Status::internal(format!("delete provider failed: {e}"))) } -async fn sandboxes_using_provider( - store: &Store, - provider_name: &str, -) -> Result, Status> { - let mut blocking = Vec::new(); - let mut offset = 0; +/// Iterate over every `Sandbox` in the store and collect items produced by +/// `f`. `f` receives each decoded sandbox; returning `Some(T)` includes the +/// value in the output, `None` skips it. +/// +/// This is the shared pagination kernel used by all sandbox-scan helpers. +async fn scan_sandboxes(store: &Store, mut f: F) -> Result, Status> +where + F: FnMut(Sandbox) -> Option, +{ + let mut out = Vec::new(); + let mut offset = 0u32; loop { let records = store .list(Sandbox::object_type(), 1000, offset) @@ -319,56 +324,50 @@ async fn sandboxes_using_provider( .map_err(|_| Status::internal("sandbox page size exceeded u32"))?, ) .ok_or_else(|| Status::internal("sandbox pagination offset overflow"))?; - for record in records { let sandbox = Sandbox::decode(record.payload.as_slice()) .map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?; - let Some(spec) = sandbox.spec.as_ref() else { - continue; - }; - if spec.providers.iter().any(|name| name == provider_name) { - blocking.push(sandbox.object_name().to_string()); + if let Some(item) = f(sandbox) { + out.push(item); } } } - blocking.sort(); - blocking.dedup(); - Ok(blocking) + Ok(out) } -async fn sandboxes_using_provider_records( +async fn sandboxes_using_provider( store: &Store, provider_name: &str, -) -> Result, Status> { - let mut sandboxes = Vec::new(); - let mut offset = 0; - loop { - let records = store - .list(Sandbox::object_type(), 1000, offset) - .await - .map_err(|e| Status::internal(format!("list sandboxes failed: {e}")))?; - if records.is_empty() { - break; +) -> Result, Status> { + let provider_name = provider_name.to_string(); + let mut names = scan_sandboxes(store, |sandbox| { + let spec = sandbox.spec.as_ref()?; + if spec.providers.iter().any(|n| n == &provider_name) { + Some(sandbox.object_name().to_string()) + } else { + None } - offset = offset - .checked_add( - u32::try_from(records.len()) - .map_err(|_| Status::internal("sandbox page size exceeded u32"))?, - ) - .ok_or_else(|| Status::internal("sandbox pagination offset overflow"))?; + }) + .await?; + names.sort(); + names.dedup(); + Ok(names) +} - for record in records { - let sandbox = Sandbox::decode(record.payload.as_slice()) - .map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?; - let Some(spec) = sandbox.spec.as_ref() else { - continue; - }; - if spec.providers.iter().any(|name| name == provider_name) { - sandboxes.push(sandbox); - } +async fn sandboxes_using_provider_records( + store: &Store, + provider_name: &str, +) -> Result, Status> { + let provider_name = provider_name.to_string(); + scan_sandboxes(store, |sandbox| { + let spec = sandbox.spec.as_ref()?; + if spec.providers.iter().any(|n| n == &provider_name) { + Some(sandbox) + } else { + None } - } - Ok(sandboxes) + }) + .await } /// Merge an incoming map into an existing map. @@ -1045,41 +1044,31 @@ fn has_errors(diagnostics: &[ProfileValidationDiagnostic]) -> bool { } async fn sandboxes_using_profile(store: &Store, profile_id: &str) -> Result, Status> { - let mut blocking = Vec::new(); - let mut offset = 0; - loop { - let records = store - .list(Sandbox::object_type(), 1000, offset) - .await - .map_err(|e| Status::internal(format!("list sandboxes failed: {e}")))?; - if records.is_empty() { - break; - } - offset = offset - .checked_add( - u32::try_from(records.len()) - .map_err(|_| Status::internal("sandbox page size exceeded u32"))?, - ) - .ok_or_else(|| Status::internal("sandbox pagination offset overflow"))?; + // Collect all sandboxes that reference at least one provider — pagination + // is handled by `scan_sandboxes`; the async provider lookup happens below. + let candidates = scan_sandboxes(store, |sandbox| { + let has_providers = sandbox + .spec + .as_ref() + .is_some_and(|s| !s.providers.is_empty()); + has_providers.then_some(sandbox) + }) + .await?; - for record in records { - let sandbox = Sandbox::decode(record.payload.as_slice()) - .map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?; - let Some(spec) = sandbox.spec.as_ref() else { + let mut blocking = Vec::new(); + for sandbox in candidates { + let spec = sandbox.spec.as_ref().expect("filtered by scan_sandboxes"); + for provider_name in &spec.providers { + let Some(provider) = store + .get_message_by_name::(provider_name) + .await + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? + else { continue; }; - for provider_name in &spec.providers { - let Some(provider) = store - .get_message_by_name::(provider_name) - .await - .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? - else { - continue; - }; - if normalize_profile_id(&provider.r#type).as_deref() == Some(profile_id) { - blocking.push(sandbox.object_name().to_string()); - break; - } + if normalize_profile_id(&provider.r#type).as_deref() == Some(profile_id) { + blocking.push(sandbox.object_name().to_string()); + break; } } } diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 4978687ed..1d5bb05e9 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -30,7 +30,7 @@ use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; use tokio_stream::wrappers::ReceiverStream; use tonic::{Request, Response, Status}; use tracing::{debug, info, warn}; @@ -737,29 +737,10 @@ pub(super) async fn handle_exec_sandbox( let (tx, rx) = mpsc::channel::>(256); tokio::spawn(async move { // Wait for the supervisor's reverse CONNECT to deliver the relay stream. - let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx) - .await - { - Ok(Ok(Ok(stream))) => stream, - Ok(Ok(Err(status))) => { - warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "ExecSandbox: relay target open failed"); - let _ = tx.send(Err(status)).await; - return; - } - Ok(Err(_)) => { - warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay channel dropped"); - let _ = tx - .send(Err(Status::unavailable("relay channel dropped"))) - .await; - return; - } - Err(_) => { - warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay open timed out"); - let _ = tx - .send(Err(Status::deadline_exceeded("relay open timed out"))) - .await; - return; - } + let Some(relay_stream) = + await_relay_stream(relay_rx, &tx, &sandbox_id, &channel_id, "ExecSandbox").await + else { + return; }; if let Err(err) = stream_exec_over_relay( @@ -782,6 +763,41 @@ pub(super) async fn handle_exec_sandbox( Ok(Response::new(ReceiverStream::new(rx))) } +/// Wait for the supervisor's reverse CONNECT to deliver a relay stream. +/// +/// Returns `Some(stream)` on success. On any failure the error is sent on `tx` +/// and `None` is returned; the caller should then `return` immediately. +async fn await_relay_stream( + relay_rx: oneshot::Receiver>, + tx: &mpsc::Sender>, + sandbox_id: &str, + channel_id: &str, + context: &str, +) -> Option { + match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx).await { + Ok(Ok(Ok(stream))) => Some(stream), + Ok(Ok(Err(status))) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "{context}: relay target open failed"); + let _ = tx.send(Err(status)).await; + None + } + Ok(Err(_)) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "{context}: relay channel dropped"); + let _ = tx + .send(Err(Status::unavailable("relay channel dropped"))) + .await; + None + } + Err(_) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "{context}: relay open timed out"); + let _ = tx + .send(Err(Status::deadline_exceeded("relay open timed out"))) + .await; + None + } + } +} + pub(super) async fn handle_forward_tcp( state: &Arc, request: Request>, @@ -831,29 +847,10 @@ pub(super) async fn handle_forward_tcp( let (tx, rx) = mpsc::channel::>(256); tokio::spawn(async move { let _connection_guard = connection_guard; - let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx) - .await - { - Ok(Ok(Ok(stream))) => stream, - Ok(Ok(Err(status))) => { - warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "ForwardTcp: relay target open failed"); - let _ = tx.send(Err(status)).await; - return; - } - Ok(Err(_)) => { - warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ForwardTcp: relay channel dropped"); - let _ = tx - .send(Err(Status::unavailable("relay channel dropped"))) - .await; - return; - } - Err(_) => { - warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ForwardTcp: relay open timed out"); - let _ = tx - .send(Err(Status::deadline_exceeded("relay open timed out"))) - .await; - return; - } + let Some(relay_stream) = + await_relay_stream(relay_rx, &tx, &sandbox_id, &channel_id, "ForwardTcp").await + else { + return; }; bridge_forward_tcp_stream(inbound, relay_stream, tx, &sandbox_id, &channel_id).await; @@ -1179,29 +1176,16 @@ pub(super) async fn handle_exec_sandbox_interactive( let (tx, rx) = mpsc::channel::>(256); tokio::spawn(async move { - let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx) - .await - { - Ok(Ok(Ok(stream))) => stream, - Ok(Ok(Err(status))) => { - warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "ExecSandboxInteractive: relay target open failed"); - let _ = tx.send(Err(status)).await; - return; - } - Ok(Err(_)) => { - warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandboxInteractive: relay channel dropped"); - let _ = tx - .send(Err(Status::unavailable("relay channel dropped"))) - .await; - return; - } - Err(_) => { - warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandboxInteractive: relay open timed out"); - let _ = tx - .send(Err(Status::deadline_exceeded("relay open timed out"))) - .await; - return; - } + let Some(relay_stream) = await_relay_stream( + relay_rx, + &tx, + &sandbox_id, + &channel_id, + "ExecSandboxInteractive", + ) + .await + else { + return; }; if let Err(err) = stream_interactive_exec_over_relay(