Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 64 additions & 75 deletions crates/openshell-server/src/grpc/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,17 @@ pub(super) async fn delete_provider_record(store: &Store, name: &str) -> Result<
.map_err(|e| Status::internal(format!("delete provider failed: {e}")))
}

async fn sandboxes_using_provider(
store: &Store,
provider_name: &str,
) -> Result<Vec<String>, Status> {
let mut blocking = Vec::new();
let mut offset = 0;
/// Iterate over every `Sandbox` in the store and collect items produced by
/// `f`. `f` receives each decoded sandbox; returning `Some(T)` includes the
/// value in the output, `None` skips it.
///
/// This is the shared pagination kernel used by all sandbox-scan helpers.
async fn scan_sandboxes<T, F>(store: &Store, mut f: F) -> Result<Vec<T>, Status>
where
F: FnMut(Sandbox) -> Option<T>,
{
let mut out = Vec::new();
let mut offset = 0u32;
loop {
let records = store
.list(Sandbox::object_type(), 1000, offset)
Expand All @@ -319,56 +324,50 @@ async fn sandboxes_using_provider(
.map_err(|_| Status::internal("sandbox page size exceeded u32"))?,
)
.ok_or_else(|| Status::internal("sandbox pagination offset overflow"))?;

for record in records {
let sandbox = Sandbox::decode(record.payload.as_slice())
.map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?;
let Some(spec) = sandbox.spec.as_ref() else {
continue;
};
if spec.providers.iter().any(|name| name == provider_name) {
blocking.push(sandbox.object_name().to_string());
if let Some(item) = f(sandbox) {
out.push(item);
}
}
}
blocking.sort();
blocking.dedup();
Ok(blocking)
Ok(out)
}

async fn sandboxes_using_provider_records(
async fn sandboxes_using_provider(
store: &Store,
provider_name: &str,
) -> Result<Vec<Sandbox>, Status> {
let mut sandboxes = Vec::new();
let mut offset = 0;
loop {
let records = store
.list(Sandbox::object_type(), 1000, offset)
.await
.map_err(|e| Status::internal(format!("list sandboxes failed: {e}")))?;
if records.is_empty() {
break;
) -> Result<Vec<String>, Status> {
let provider_name = provider_name.to_string();
let mut names = scan_sandboxes(store, |sandbox| {
let spec = sandbox.spec.as_ref()?;
if spec.providers.iter().any(|n| n == &provider_name) {
Some(sandbox.object_name().to_string())
} else {
None
}
offset = offset
.checked_add(
u32::try_from(records.len())
.map_err(|_| Status::internal("sandbox page size exceeded u32"))?,
)
.ok_or_else(|| Status::internal("sandbox pagination offset overflow"))?;
})
.await?;
names.sort();
names.dedup();
Ok(names)
}

for record in records {
let sandbox = Sandbox::decode(record.payload.as_slice())
.map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?;
let Some(spec) = sandbox.spec.as_ref() else {
continue;
};
if spec.providers.iter().any(|name| name == provider_name) {
sandboxes.push(sandbox);
}
async fn sandboxes_using_provider_records(
store: &Store,
provider_name: &str,
) -> Result<Vec<Sandbox>, Status> {
let provider_name = provider_name.to_string();
scan_sandboxes(store, |sandbox| {
let spec = sandbox.spec.as_ref()?;
if spec.providers.iter().any(|n| n == &provider_name) {
Some(sandbox)
} else {
None
}
}
Ok(sandboxes)
})
.await
}

/// Merge an incoming map into an existing map.
Expand Down Expand Up @@ -1045,41 +1044,31 @@ fn has_errors(diagnostics: &[ProfileValidationDiagnostic]) -> bool {
}

async fn sandboxes_using_profile(store: &Store, profile_id: &str) -> Result<Vec<String>, Status> {
let mut blocking = Vec::new();
let mut offset = 0;
loop {
let records = store
.list(Sandbox::object_type(), 1000, offset)
.await
.map_err(|e| Status::internal(format!("list sandboxes failed: {e}")))?;
if records.is_empty() {
break;
}
offset = offset
.checked_add(
u32::try_from(records.len())
.map_err(|_| Status::internal("sandbox page size exceeded u32"))?,
)
.ok_or_else(|| Status::internal("sandbox pagination offset overflow"))?;
// Collect all sandboxes that reference at least one provider — pagination
// is handled by `scan_sandboxes`; the async provider lookup happens below.
let candidates = scan_sandboxes(store, |sandbox| {
let has_providers = sandbox
.spec
.as_ref()
.is_some_and(|s| !s.providers.is_empty());
has_providers.then_some(sandbox)
})
.await?;

for record in records {
let sandbox = Sandbox::decode(record.payload.as_slice())
.map_err(|e| Status::internal(format!("decode sandbox failed: {e}")))?;
let Some(spec) = sandbox.spec.as_ref() else {
let mut blocking = Vec::new();
for sandbox in candidates {
let spec = sandbox.spec.as_ref().expect("filtered by scan_sandboxes");
for provider_name in &spec.providers {
let Some(provider) = store
.get_message_by_name::<Provider>(provider_name)
.await
.map_err(|e| Status::internal(format!("fetch provider failed: {e}")))?
else {
continue;
};
for provider_name in &spec.providers {
let Some(provider) = store
.get_message_by_name::<Provider>(provider_name)
.await
.map_err(|e| Status::internal(format!("fetch provider failed: {e}")))?
else {
continue;
};
if normalize_profile_id(&provider.r#type).as_deref() == Some(profile_id) {
blocking.push(sandbox.object_name().to_string());
break;
}
if normalize_profile_id(&provider.r#type).as_deref() == Some(profile_id) {
blocking.push(sandbox.object_name().to_string());
break;
}
}
}
Expand Down
124 changes: 54 additions & 70 deletions crates/openshell-server/src/grpc/sandbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status};
use tracing::{debug, info, warn};
Expand Down Expand Up @@ -737,29 +737,10 @@ pub(super) async fn handle_exec_sandbox(
let (tx, rx) = mpsc::channel::<Result<ExecSandboxEvent, Status>>(256);
tokio::spawn(async move {
// Wait for the supervisor's reverse CONNECT to deliver the relay stream.
let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx)
.await
{
Ok(Ok(Ok(stream))) => stream,
Ok(Ok(Err(status))) => {
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "ExecSandbox: relay target open failed");
let _ = tx.send(Err(status)).await;
return;
}
Ok(Err(_)) => {
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay channel dropped");
let _ = tx
.send(Err(Status::unavailable("relay channel dropped")))
.await;
return;
}
Err(_) => {
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay open timed out");
let _ = tx
.send(Err(Status::deadline_exceeded("relay open timed out")))
.await;
return;
}
let Some(relay_stream) =
await_relay_stream(relay_rx, &tx, &sandbox_id, &channel_id, "ExecSandbox").await
else {
return;
};

if let Err(err) = stream_exec_over_relay(
Expand All @@ -782,6 +763,41 @@ pub(super) async fn handle_exec_sandbox(
Ok(Response::new(ReceiverStream::new(rx)))
}

/// Wait for the supervisor's reverse CONNECT to deliver a relay stream.
///
/// Returns `Some(stream)` on success. On any failure the error is sent on `tx`
/// and `None` is returned; the caller should then `return` immediately.
async fn await_relay_stream<T: Send + 'static>(
relay_rx: oneshot::Receiver<Result<tokio::io::DuplexStream, Status>>,
tx: &mpsc::Sender<Result<T, Status>>,
sandbox_id: &str,
channel_id: &str,
context: &str,
) -> Option<tokio::io::DuplexStream> {
match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx).await {
Ok(Ok(Ok(stream))) => Some(stream),
Ok(Ok(Err(status))) => {
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "{context}: relay target open failed");
let _ = tx.send(Err(status)).await;
None
}
Ok(Err(_)) => {
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "{context}: relay channel dropped");
let _ = tx
.send(Err(Status::unavailable("relay channel dropped")))
.await;
None
}
Err(_) => {
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "{context}: relay open timed out");
let _ = tx
.send(Err(Status::deadline_exceeded("relay open timed out")))
.await;
None
}
}
}

pub(super) async fn handle_forward_tcp(
state: &Arc<ServerState>,
request: Request<tonic::Streaming<TcpForwardFrame>>,
Expand Down Expand Up @@ -831,29 +847,10 @@ pub(super) async fn handle_forward_tcp(
let (tx, rx) = mpsc::channel::<Result<TcpForwardFrame, Status>>(256);
tokio::spawn(async move {
let _connection_guard = connection_guard;
let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx)
.await
{
Ok(Ok(Ok(stream))) => stream,
Ok(Ok(Err(status))) => {
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "ForwardTcp: relay target open failed");
let _ = tx.send(Err(status)).await;
return;
}
Ok(Err(_)) => {
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ForwardTcp: relay channel dropped");
let _ = tx
.send(Err(Status::unavailable("relay channel dropped")))
.await;
return;
}
Err(_) => {
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ForwardTcp: relay open timed out");
let _ = tx
.send(Err(Status::deadline_exceeded("relay open timed out")))
.await;
return;
}
let Some(relay_stream) =
await_relay_stream(relay_rx, &tx, &sandbox_id, &channel_id, "ForwardTcp").await
else {
return;
};

bridge_forward_tcp_stream(inbound, relay_stream, tx, &sandbox_id, &channel_id).await;
Expand Down Expand Up @@ -1179,29 +1176,16 @@ pub(super) async fn handle_exec_sandbox_interactive(

let (tx, rx) = mpsc::channel::<Result<ExecSandboxEvent, Status>>(256);
tokio::spawn(async move {
let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx)
.await
{
Ok(Ok(Ok(stream))) => stream,
Ok(Ok(Err(status))) => {
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, error = %status.message(), "ExecSandboxInteractive: relay target open failed");
let _ = tx.send(Err(status)).await;
return;
}
Ok(Err(_)) => {
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandboxInteractive: relay channel dropped");
let _ = tx
.send(Err(Status::unavailable("relay channel dropped")))
.await;
return;
}
Err(_) => {
warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandboxInteractive: relay open timed out");
let _ = tx
.send(Err(Status::deadline_exceeded("relay open timed out")))
.await;
return;
}
let Some(relay_stream) = await_relay_stream(
relay_rx,
&tx,
&sandbox_id,
&channel_id,
"ExecSandboxInteractive",
)
.await
else {
return;
};

if let Err(err) = stream_interactive_exec_over_relay(
Expand Down
Loading