From 4aae038ffc9bcca6fe363dbdbcbf739764331783 Mon Sep 17 00:00:00 2001 From: Drew Newberry Date: Fri, 1 May 2026 17:55:31 -0700 Subject: [PATCH] fix(docker): harden supervisor startup and gateway routing Signed-off-by: Drew Newberry --- Cargo.lock | 1 + architecture/build-containers.md | 2 +- architecture/gateway-security.md | 4 +- architecture/gateway.md | 18 +- crates/openshell-core/src/config.rs | 11 +- crates/openshell-driver-docker/Cargo.toml | 1 + crates/openshell-driver-docker/src/lib.rs | 255 ++++++++++++++++-- crates/openshell-driver-docker/src/tests.rs | 203 ++++++++++++-- .../src/sandbox/linux/netns.rs | 83 +++++- crates/openshell-server/src/cli.rs | 36 +-- crates/openshell-server/src/compute/mod.rs | 16 ++ crates/openshell-server/src/lib.rs | 211 +++++++++------ deploy/docker/Dockerfile.images | 2 +- .../helm/openshell/templates/statefulset.yaml | 2 + 14 files changed, 687 insertions(+), 158 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c6409ab12..bfaa55d93 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3188,6 +3188,7 @@ dependencies = [ "tokio-stream", "tonic", "tracing", + "url", ] [[package]] diff --git a/architecture/build-containers.md b/architecture/build-containers.md index 7886c766c..ba13899cf 100644 --- a/architecture/build-containers.md +++ b/architecture/build-containers.md @@ -9,7 +9,7 @@ The gateway runs the control plane API server. 
It is deployed as a StatefulSet i - **Docker target**: `gateway` in `deploy/docker/Dockerfile.images` - **Registry**: `ghcr.io/nvidia/openshell/gateway:latest` - **Pulled when**: Cluster startup (the Helm chart triggers the pull) -- **Entrypoint**: `openshell-gateway --port 8080` (gRPC + HTTP, mTLS) +- **Entrypoint**: `openshell-gateway --bind-address 0.0.0.0 --port 8080` (gRPC + HTTP, mTLS) ## Cluster (`openshell/cluster`) diff --git a/architecture/gateway-security.md b/architecture/gateway-security.md index a32c3fb52..36a5d3d6d 100644 --- a/architecture/gateway-security.md +++ b/architecture/gateway-security.md @@ -304,9 +304,9 @@ Traffic flows through several layers from the host to the gateway process: | Container | `30051` | Hardcoded in `crates/openshell-bootstrap/src/docker.rs` | | k3s NodePort | `30051` | `deploy/helm/openshell/values.yaml` (`service.nodePort`) | | k3s Service | `8080` | `deploy/helm/openshell/values.yaml` (`service.port`) | -| Server bind | `8080` | `--port` flag / `OPENSHELL_SERVER_PORT` env var | +| Server bind | `0.0.0.0:8080` in deployed containers | `--bind-address 0.0.0.0 --port 8080` / `OPENSHELL_BIND_ADDRESS` + `OPENSHELL_SERVER_PORT` | -Docker maps `host_port → 30051/tcp`. Inside k3s, the NodePort service maps `30051 → 8080 (pod port)`. The server binds `0.0.0.0:8080`. +Docker maps `host_port → 30051/tcp`. Inside k3s, the NodePort service maps `30051 → 8080 (pod port)`. The deployed gateway container binds `0.0.0.0:8080` explicitly. ## Security Model Summary diff --git a/architecture/gateway.md b/architecture/gateway.md index 8e2724bc6..ab865d6ef 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -107,14 +107,16 @@ The gateway boots in `cli::run_cli` (`crates/openshell-server/src/cli.rs`) and p - `docker` constructs `openshell-driver-docker` in-process and manages local containers labeled with the configured sandbox namespace. 
- `vm` spawns the standalone `openshell-driver-vm` binary as a local compute-driver process, resolves it from `--driver-dir`, conventional libexec install paths, or a sibling of the gateway binary, connects to it over a Unix domain socket, and keeps the libkrun/rootfs runtime out of the gateway binary. 3. Build `ServerState` (shared via `Arc` across all handlers), including a fresh `SupervisorSessionRegistry`. - 4. **Spawn background tasks**: + 4. Resume persisted sandboxes that were stopped during the previous gateway shutdown. + 5. **Spawn background tasks**: - `ComputeRuntime::spawn_watchers` -- consumes the compute-driver watch stream, republishes platform events, and runs a periodic `ListSandboxes` snapshot reconcile. - `ssh_tunnel::spawn_session_reaper` -- sweeps expired or revoked SSH session tokens from the store hourly. - `supervisor_session::spawn_relay_reaper` -- sweeps orphaned pending relay channels every 30 seconds. - 5. Create `MultiplexService`. - 6. Bind `TcpListener` on `config.bind_address`. - 7. Optionally create `TlsAcceptor` from cert/key files. - 8. Enter the accept loop: for each connection, spawn a tokio task that optionally performs a TLS handshake, then calls `MultiplexService::serve()`. + 6. Create `MultiplexService`. + 7. Bind the primary gateway listener and any compute-driver requested listeners. Docker requests the Docker bridge gateway address with the normal gateway port, so sandbox containers can call back over the bridge without joining the host network. + 8. Bind optional health and metrics listeners. + 9. Optionally create `TlsAcceptor` from cert/key files. + 10. Spawn a task per gateway listener. Each accepted connection optionally performs a TLS handshake, then calls `MultiplexService::serve()`. ## Configuration @@ -122,8 +124,8 @@ All configuration is via CLI flags with environment variable fallbacks. 
The `--d | Flag | Env Var | Default | Description | |------|---------|---------|-------------| +| `--bind-address` | `OPENSHELL_BIND_ADDRESS` | `127.0.0.1` | IP address for gateway, health, and metrics listeners. Container deployments pass `0.0.0.0` explicitly. | | `--port` | `OPENSHELL_SERVER_PORT` | `8080` | TCP listen port | -| `--bind-address` | `OPENSHELL_BIND_ADDRESS` | `0.0.0.0` | Address for the main gateway listener | | `--log-level` | `OPENSHELL_LOG_LEVEL` | `info` | Tracing log level filter | | `--tls-cert` | `OPENSHELL_TLS_CERT` | None | Path to PEM certificate file | | `--tls-key` | `OPENSHELL_TLS_KEY` | None | Path to PEM private key file | @@ -136,6 +138,7 @@ All configuration is via CLI flags with environment variable fallbacks. The `--d | `--sandbox-image` | `OPENSHELL_SANDBOX_IMAGE` | None | Default container image for sandbox pods | | `--grpc-endpoint` | `OPENSHELL_GRPC_ENDPOINT` | None | gRPC endpoint reachable from within the cluster (for supervisor callbacks) | | `--drivers` | `OPENSHELL_DRIVERS` | `kubernetes` | Compute backend to use. Current options are `kubernetes`, `docker`, and `vm`. | +| `--docker-network-name` | `OPENSHELL_DOCKER_NETWORK_NAME` | `openshell-docker` | Docker bridge network that local Docker sandboxes join | | `--vm-driver-state-dir` | `OPENSHELL_VM_DRIVER_STATE_DIR` | `target/openshell-vm-driver` | Host directory for VM sandbox rootfs, console logs, and runtime state | | `--driver-dir` | `OPENSHELL_DRIVER_DIR` | unset | Override directory for `openshell-driver-vm`. When unset, the gateway searches `~/.local/libexec/openshell`, `/usr/libexec/openshell`, `/usr/local/libexec/openshell`, `/usr/local/libexec`, then a sibling binary. 
| | `--vm-krun-log-level` | `OPENSHELL_VM_KRUN_LOG_LEVEL` | `1` | libkrun log level for VM helper processes | @@ -608,6 +611,9 @@ The gateway reaches the sandbox exclusively through the supervisor-initiated `Co The Docker driver (`crates/openshell-driver-docker/src/lib.rs`) is an in-process compute backend for local standalone gateways. It creates one Docker container per sandbox, labels each container with `openshell.ai/managed-by=openshell`, `openshell.ai/sandbox-id`, `openshell.ai/sandbox-name`, and `openshell.ai/sandbox-namespace`, and bind-mounts a Linux `openshell-sandbox` supervisor binary into the container. - **Create**: Pulls or validates the sandbox image according to `sandbox_image_pull_policy`, creates a labeled container, mounts the supervisor binary and optional TLS material, and starts the container with the supervisor as entrypoint. +- **Bridge networking**: Ensures a local Docker bridge network exists (`openshell-docker` by default) and starts every sandbox container on that network instead of using `network_mode=host`. +- **Gateway callback routing**: On native Linux Docker, injects `host.openshell.internal` with the bridge gateway IP and reports that bridge gateway IP plus the normal gateway port to `run_server()` as an extra listener. If the primary listener already binds the wildcard address for that port, the extra address is covered and is not bound a second time. On Docker Desktop, the bridge gateway IP belongs to Docker Desktop's VM rather than the macOS/Windows host, so the driver maps `host.openshell.internal` to Docker's `host-gateway` alias and does not request an extra listener. `OPENSHELL_ENDPOINT` inside Docker sandboxes uses the configured scheme and points at `host.openshell.internal:<port>` in both cases. 
+- **Environment ownership**: Merges template and spec environment first, then overwrites driver-owned supervisor variables, including `PATH`, `OPENSHELL_ENDPOINT`, `OPENSHELL_SANDBOX_ID`, `OPENSHELL_SSH_SOCKET_PATH`, and `OPENSHELL_SANDBOX_COMMAND`. This keeps privileged supervisor setup from resolving helper binaries through a user-controlled search path. - **List/Get/Watch**: Reads labeled containers in the configured sandbox namespace and derives driver-native sandbox status from Docker state plus supervisor relay readiness. - **Stop**: Stops the matching labeled container without deleting it. - **Delete**: Force-removes the matching labeled container. diff --git a/crates/openshell-core/src/config.rs b/crates/openshell-core/src/config.rs index cf4190975..1ec06677b 100644 --- a/crates/openshell-core/src/config.rs +++ b/crates/openshell-core/src/config.rs @@ -31,6 +31,9 @@ pub const DEFAULT_SSH_HANDSHAKE_SKEW_SECS: u64 = 300; /// Default Podman bridge network name. pub const DEFAULT_NETWORK_NAME: &str = "openshell"; +/// Default Docker bridge network name for local sandboxes. +pub const DEFAULT_DOCKER_NETWORK_NAME: &str = "openshell-docker"; + /// Default OCI image for the openshell-sandbox supervisor binary. 
pub const DEFAULT_SUPERVISOR_IMAGE: &str = "openshell/supervisor:latest"; @@ -515,7 +518,7 @@ impl Config { } fn default_bind_address() -> SocketAddr { - "0.0.0.0:8080".parse().expect("valid default address") + "127.0.0.1:8080".parse().expect("valid default address") } fn default_log_level() -> String { @@ -589,6 +592,12 @@ mod tests { assert!(err.contains("unsupported compute driver 'firecracker'")); } + #[test] + fn config_defaults_to_loopback_bind_address() { + let expected: SocketAddr = "127.0.0.1:8080".parse().expect("valid address"); + assert_eq!(Config::new(None).bind_address, expected); + } + #[test] fn config_new_disables_health_bind_by_default() { let cfg = Config::new(None); diff --git a/crates/openshell-driver-docker/Cargo.toml b/crates/openshell-driver-docker/Cargo.toml index ee917b78d..79d4fb37d 100644 --- a/crates/openshell-driver-docker/Cargo.toml +++ b/crates/openshell-driver-docker/Cargo.toml @@ -22,6 +22,7 @@ bytes = { workspace = true } bollard = { version = "0.20" } tar = "0.4" tempfile = "3" +url = { workspace = true } [lints] workspace = true diff --git a/crates/openshell-driver-docker/src/lib.rs b/crates/openshell-driver-docker/src/lib.rs index 8b8df5b89..b1f1c421f 100644 --- a/crates/openshell-driver-docker/src/lib.rs +++ b/crates/openshell-driver-docker/src/lib.rs @@ -8,8 +8,9 @@ use bollard::Docker; use bollard::errors::Error as BollardError; use bollard::models::{ - ContainerCreateBody, ContainerSummary, ContainerSummaryStateEnum, DeviceRequest, HostConfig, - Mount, MountTypeEnum, RestartPolicy, RestartPolicyNameEnum, + ContainerCreateBody, ContainerSummary, ContainerSummaryStateEnum, DeviceRequest, + EndpointSettings, HostConfig, Mount, MountTypeEnum, NetworkCreateRequest, NetworkingConfig, + RestartPolicy, RestartPolicyNameEnum, SystemInfo, }; use bollard::query_parameters::{ CreateContainerOptionsBuilder, CreateImageOptions, DownloadFromContainerOptionsBuilder, @@ -17,7 +18,9 @@ use bollard::query_parameters::{ }; use bytes::Bytes; 
use futures::{Stream, StreamExt}; -use openshell_core::config::{CDI_GPU_DEVICE_ALL, DEFAULT_STOP_TIMEOUT_SECS}; +use openshell_core::config::{ + CDI_GPU_DEVICE_ALL, DEFAULT_DOCKER_NETWORK_NAME, DEFAULT_STOP_TIMEOUT_SECS, +}; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, DriverCondition, DriverSandbox, DriverSandboxStatus, DriverSandboxTemplate, @@ -30,6 +33,7 @@ use openshell_core::proto::compute::v1::{ use openshell_core::{Config, Error, Result as CoreResult}; use std::collections::HashMap; use std::io::Read; +use std::net::{IpAddr, SocketAddr}; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::sync::Arc; @@ -38,6 +42,7 @@ use tokio::sync::{broadcast, mpsc}; use tokio_stream::wrappers::ReceiverStream; use tonic::{Request, Response, Status}; use tracing::{info, warn}; +use url::Url; const WATCH_BUFFER: usize = 128; const WATCH_POLL_INTERVAL: Duration = Duration::from_secs(2); @@ -54,7 +59,10 @@ const TLS_CA_MOUNT_PATH: &str = "/etc/openshell/tls/client/ca.crt"; const TLS_CERT_MOUNT_PATH: &str = "/etc/openshell/tls/client/tls.crt"; const TLS_KEY_MOUNT_PATH: &str = "/etc/openshell/tls/client/tls.key"; const SANDBOX_COMMAND: &str = "sleep infinity"; -const HOST_OPENSHELL_INTERNAL_HOSTS_ENTRY: &str = "host.openshell.internal:127.0.0.1"; +const SUPERVISOR_PATH: &str = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"; +const HOST_OPENSHELL_INTERNAL: &str = "host.openshell.internal"; +const HOST_DOCKER_INTERNAL: &str = "host.docker.internal"; +const DOCKER_NETWORK_DRIVER: &str = "bridge"; /// Default image holding the Linux `openshell-sandbox` binary. The gateway /// pulls this image and extracts the binary to a host-side cache when no @@ -136,6 +144,9 @@ pub struct DockerComputeConfig { /// Host-side private key for Docker sandbox mTLS. pub guest_tls_key: Option, + + /// Docker bridge network that sandbox containers join. 
+ pub network_name: String, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -151,6 +162,8 @@ struct DockerDriverRuntimeConfig { image_pull_policy: String, sandbox_namespace: String, grpc_endpoint: String, + network_name: String, + gateway_route: DockerGatewayRoute, ssh_socket_path: String, stop_timeout_secs: u32, log_level: String, @@ -160,6 +173,15 @@ struct DockerDriverRuntimeConfig { supports_gpu: bool, } +#[derive(Debug, Clone, PartialEq, Eq)] +enum DockerGatewayRoute { + Bridge { + bind_address: SocketAddr, + host_alias_ip: IpAddr, + }, + HostGateway, +} + #[derive(Clone)] pub struct DockerComputeDriver { docker: Arc, @@ -194,12 +216,27 @@ impl DockerComputeDriver { let version = docker.version().await.map_err(|err| { Error::execution(format!("failed to query Docker daemon version: {err}")) })?; - let supports_gpu = docker - .info() - .await - .ok() - .and_then(|info| info.cdi_spec_dirs) + let info = docker.info().await.map_err(|err| { + Error::execution(format!("failed to query Docker daemon info: {err}")) + })?; + let supports_gpu = info + .cdi_spec_dirs + .as_ref() .is_some_and(|dirs| !dirs.is_empty()); + let gateway_port = config.bind_address.port(); + if gateway_port == 0 { + return Err(Error::config( + "docker compute driver requires a fixed non-zero gateway bind port", + )); + } + let network_name = docker_network_name(docker_config)?; + let bridge_gateway_ip = ensure_bridge_network(&docker, &network_name).await?; + let gateway_route = docker_gateway_route(&info, bridge_gateway_ip, gateway_port); + let grpc_endpoint = docker_container_openshell_endpoint( + &config.grpc_endpoint, + HOST_OPENSHELL_INTERNAL, + gateway_port, + ); let daemon_arch = normalize_docker_arch(version.arch.as_deref().unwrap_or_default()); let supervisor_bin = resolve_supervisor_bin(&docker, docker_config, &daemon_arch).await?; let guest_tls = docker_guest_tls_paths(config, docker_config)?; @@ -210,7 +247,9 @@ impl DockerComputeDriver { default_image: config.sandbox_image.clone(), 
image_pull_policy: config.sandbox_image_pull_policy.clone(), sandbox_namespace: config.sandbox_namespace.clone(), - grpc_endpoint: config.grpc_endpoint.clone(), + grpc_endpoint, + network_name, + gateway_route, ssh_socket_path: config.sandbox_ssh_socket_path.clone(), stop_timeout_secs: DEFAULT_STOP_TIMEOUT_SECS, log_level: config.log_level.clone(), @@ -231,6 +270,14 @@ impl DockerComputeDriver { Ok(driver) } + #[must_use] + pub fn gateway_bind_addresses(&self) -> Vec { + match self.config.gateway_route { + DockerGatewayRoute::Bridge { bind_address, .. } => vec![bind_address], + DockerGatewayRoute::HostGateway => Vec::new(), + } + } + fn capabilities(&self) -> GetCapabilitiesResponse { GetCapabilitiesResponse { driver_name: "docker".to_string(), @@ -830,10 +877,7 @@ fn bind_mount(source: &Path, target: &str, read_only: bool) -> Mount { fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig) -> Vec { let mut environment = HashMap::from([ ("HOME".to_string(), "/root".to_string()), - ( - "PATH".to_string(), - "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin".to_string(), - ), + ("PATH".to_string(), SUPERVISOR_PATH.to_string()), ("TERM".to_string(), "xterm".to_string()), ( "OPENSHELL_LOG_LEVEL".to_string(), @@ -862,6 +906,9 @@ fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig "OPENSHELL_SANDBOX_COMMAND".to_string(), SANDBOX_COMMAND.to_string(), ); + // The root supervisor executes namespace helpers during bootstrap; keep + // their search path driver-owned even when the template/spec set PATH. + environment.insert("PATH".to_string(), SUPERVISOR_PATH.to_string()); if config.guest_tls.is_some() { environment.insert( "OPENSHELL_TLS_CA".to_string(), @@ -959,18 +1006,16 @@ fn build_container_create_body( // container layer is redundant relative to those controls // and conflicts with them in this case. 
security_opt: Some(vec!["apparmor=unconfined".to_string()]), - // Run in the host network namespace so a gateway bound to - // 127.0.0.1 is reachable from the supervisor as 127.0.0.1. - // The supervisor still creates a nested network namespace for - // the sandboxed workload and forces workload traffic through - // its policy proxy. - network_mode: Some("host".to_string()), - // Keep a stable host alias available inside the container without - // requiring users to edit the host's /etc/hosts. In host network - // mode this resolves back to the host loopback gateway. - extra_hosts: Some(vec![HOST_OPENSHELL_INTERNAL_HOSTS_ENTRY.to_string()]), + network_mode: Some(config.network_name.clone()), + extra_hosts: Some(docker_extra_hosts(&config.gateway_route)), ..Default::default() }), + networking_config: Some(NetworkingConfig { + endpoints_config: Some(HashMap::from([( + config.network_name.clone(), + EndpointSettings::default(), + )])), + }), ..Default::default() }) } @@ -999,6 +1044,158 @@ fn sandbox_log_level(sandbox: &DriverSandbox, default_level: &str) -> String { .to_string() } +fn docker_container_openshell_endpoint(endpoint: &str, host: &str, port: u16) -> String { + let Ok(mut url) = Url::parse(endpoint) else { + return endpoint.to_string(); + }; + + if url.set_host(Some(host)).is_ok() && url.set_port(Some(port)).is_ok() { + return url.to_string(); + } + + endpoint.to_string() +} + +fn docker_network_name(config: &DockerComputeConfig) -> CoreResult { + let name = config.network_name.trim(); + if name.is_empty() { + return Ok(DEFAULT_DOCKER_NETWORK_NAME.to_string()); + } + Ok(name.to_string()) +} + +fn docker_gateway_route( + info: &SystemInfo, + bridge_gateway_ip: IpAddr, + port: u16, +) -> DockerGatewayRoute { + if is_docker_desktop(info) { + DockerGatewayRoute::HostGateway + } else { + DockerGatewayRoute::Bridge { + bind_address: SocketAddr::new(bridge_gateway_ip, port), + host_alias_ip: bridge_gateway_ip, + } + } +} + +fn is_docker_desktop(info: &SystemInfo) 
-> bool { + let operating_system = info + .operating_system + .as_deref() + .unwrap_or_default() + .to_ascii_lowercase(); + if operating_system.contains("docker desktop") { + return true; + } + + info.labels.as_ref().is_some_and(|labels| { + labels + .iter() + .any(|label| label.starts_with("com.docker.desktop.")) + }) +} + +fn docker_extra_hosts(route: &DockerGatewayRoute) -> Vec { + match route { + DockerGatewayRoute::Bridge { host_alias_ip, .. } => vec![ + format!("{HOST_DOCKER_INTERNAL}:{host_alias_ip}"), + format!("{HOST_OPENSHELL_INTERNAL}:{host_alias_ip}"), + ], + DockerGatewayRoute::HostGateway => { + vec![format!("{HOST_OPENSHELL_INTERNAL}:host-gateway")] + } + } +} + +async fn ensure_bridge_network(docker: &Docker, network_name: &str) -> CoreResult { + match docker.inspect_network(network_name, None).await { + Ok(network) => return validate_bridge_network(network_name, &network), + Err(err) if !is_not_found_error(&err) => { + return Err(Error::execution(format!( + "failed to inspect Docker network '{network_name}': {err}" + ))); + } + Err(_) => {} + } + + docker + .create_network(NetworkCreateRequest { + name: network_name.to_string(), + driver: Some(DOCKER_NETWORK_DRIVER.to_string()), + attachable: Some(true), + labels: Some(HashMap::from([( + MANAGED_BY_LABEL_KEY.to_string(), + MANAGED_BY_LABEL_VALUE.to_string(), + )])), + ..Default::default() + }) + .await + .map(|_| ()) + .or_else(|err| { + if is_conflict_error(&err) { + Ok(()) + } else { + Err(Error::execution(format!( + "failed to create Docker network '{network_name}': {err}" + ))) + } + })?; + + let network = docker + .inspect_network(network_name, None) + .await + .map_err(|err| { + Error::execution(format!( + "failed to inspect Docker network '{network_name}' after create: {err}" + )) + })?; + validate_bridge_network(network_name, &network) +} + +fn validate_bridge_network( + network_name: &str, + network: &bollard::models::NetworkInspect, +) -> CoreResult { + if network.driver.as_deref() != 
Some(DOCKER_NETWORK_DRIVER) { + return Err(Error::config(format!( + "Docker network '{network_name}' must use the '{DOCKER_NETWORK_DRIVER}' driver, found '{}'", + network.driver.as_deref().unwrap_or("unknown") + ))); + } + + docker_bridge_gateway_ip(network_name, network) +} + +fn docker_bridge_gateway_ip( + network_name: &str, + network: &bollard::models::NetworkInspect, +) -> CoreResult { + let Some(configs) = network.ipam.as_ref().and_then(|ipam| ipam.config.as_ref()) else { + return Err(Error::config(format!( + "Docker bridge network '{network_name}' does not expose IPAM gateway configuration" + ))); + }; + + for config in configs { + let Some(gateway) = config.gateway.as_deref() else { + continue; + }; + let ip = gateway.parse::().map_err(|err| { + Error::config(format!( + "Docker bridge network '{network_name}' has invalid gateway '{gateway}': {err}" + )) + })?; + if matches!(ip, IpAddr::V4(_)) { + return Ok(ip); + } + } + + Err(Error::config(format!( + "Docker bridge network '{network_name}' does not have an IPv4 IPAM gateway" + ))) +} + fn docker_resource_limits( template: &DriverSandboxTemplate, ) -> Result { @@ -1776,6 +1973,16 @@ fn is_not_found_error(err: &BollardError) -> bool { ) } +fn is_conflict_error(err: &BollardError) -> bool { + matches!( + err, + BollardError::DockerResponseServerError { + status_code: 409, + .. 
+ } + ) +} + fn is_not_modified_error(err: &BollardError) -> bool { matches!( err, diff --git a/crates/openshell-driver-docker/src/tests.rs b/crates/openshell-driver-docker/src/tests.rs index b016b31eb..ae93f5b66 100644 --- a/crates/openshell-driver-docker/src/tests.rs +++ b/crates/openshell-driver-docker/src/tests.rs @@ -2,10 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 use super::*; +use openshell_core::config::DEFAULT_SERVER_PORT; use openshell_core::proto::compute::v1::{ DriverResourceRequirements, DriverSandboxSpec, DriverSandboxTemplate, }; use std::fs; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use tempfile::TempDir; const TLS_MOUNT_DIR: &str = "/etc/openshell/tls/client"; @@ -42,6 +44,14 @@ fn runtime_config() -> DockerDriverRuntimeConfig { image_pull_policy: String::new(), sandbox_namespace: "default".to_string(), grpc_endpoint: "https://localhost:8443".to_string(), + network_name: DEFAULT_DOCKER_NETWORK_NAME.to_string(), + gateway_route: DockerGatewayRoute::Bridge { + bind_address: SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(172, 18, 0, 1)), + DEFAULT_SERVER_PORT, + ), + host_alias_ip: IpAddr::V4(Ipv4Addr::new(172, 18, 0, 1)), + }, ssh_socket_path: "/run/openshell/ssh.sock".to_string(), stop_timeout_secs: DEFAULT_STOP_TIMEOUT_SECS, log_level: "info".to_string(), @@ -57,16 +67,129 @@ fn runtime_config() -> DockerDriverRuntimeConfig { } #[test] -fn build_environment_preserves_loopback_endpoint_for_host_network() { - let mut config = runtime_config(); - config.grpc_endpoint = "http://127.0.0.1:8080".to_string(); +fn container_visible_endpoint_rewrites_loopback_hosts() { + assert_eq!( + docker_container_openshell_endpoint( + "https://localhost:8443", + HOST_OPENSHELL_INTERNAL, + DEFAULT_SERVER_PORT, + ), + "https://host.openshell.internal:8080/" + ); + assert_eq!( + docker_container_openshell_endpoint( + "http://127.0.0.1:8080", + HOST_OPENSHELL_INTERNAL, + DEFAULT_SERVER_PORT, + ), + "http://host.openshell.internal:8080/" + ); + assert_eq!( + 
docker_container_openshell_endpoint( + "https://gateway.internal:8443", + HOST_OPENSHELL_INTERNAL, + DEFAULT_SERVER_PORT, + ), + "https://host.openshell.internal:8080/" + ); +} + +#[test] +fn docker_bridge_gateway_ip_requires_ipv4_gateway() { + let network = bollard::models::NetworkInspect { + driver: Some(DOCKER_NETWORK_DRIVER.to_string()), + ipam: Some(bollard::models::Ipam { + config: Some(vec![ + bollard::models::IpamConfig { + gateway: Some("fd00::1".to_string()), + ..Default::default() + }, + bollard::models::IpamConfig { + gateway: Some("172.18.0.1".to_string()), + ..Default::default() + }, + ]), + ..Default::default() + }), + ..Default::default() + }; + + assert_eq!( + docker_bridge_gateway_ip(DEFAULT_DOCKER_NETWORK_NAME, &network).unwrap(), + IpAddr::V4(Ipv4Addr::new(172, 18, 0, 1)) + ); + + let ipv6_only_network = bollard::models::NetworkInspect { + driver: Some(DOCKER_NETWORK_DRIVER.to_string()), + ipam: Some(bollard::models::Ipam { + config: Some(vec![bollard::models::IpamConfig { + gateway: Some("fd00::1".to_string()), + ..Default::default() + }]), + ..Default::default() + }), + ..Default::default() + }; + + assert!( + docker_bridge_gateway_ip(DEFAULT_DOCKER_NETWORK_NAME, &ipv6_only_network) + .unwrap_err() + .to_string() + .contains("IPv4 IPAM gateway") + ); +} + +#[test] +fn docker_gateway_route_uses_host_gateway_for_docker_desktop() { + let info = SystemInfo { + operating_system: Some("Docker Desktop".to_string()), + labels: Some(vec![ + "com.docker.desktop.address=unix:///tmp/docker.sock".to_string(), + ]), + ..Default::default() + }; - let env = build_environment(&test_sandbox(), &config); - assert!(env.contains(&"OPENSHELL_ENDPOINT=http://127.0.0.1:8080".to_string())); + assert_eq!( + docker_gateway_route( + &info, + IpAddr::V4(Ipv4Addr::new(172, 18, 0, 1)), + DEFAULT_SERVER_PORT, + ), + DockerGatewayRoute::HostGateway + ); + assert_eq!( + docker_extra_hosts(&DockerGatewayRoute::HostGateway), + 
vec!["host.openshell.internal:host-gateway".to_string()] + ); +} - config.grpc_endpoint = "https://localhost:8443".to_string(); - let env = build_environment(&test_sandbox(), &config); - assert!(env.contains(&"OPENSHELL_ENDPOINT=https://localhost:8443".to_string())); +#[test] +fn docker_gateway_route_uses_bridge_gateway_for_linux_docker() { + let info = SystemInfo { + operating_system: Some("Ubuntu 24.04 LTS".to_string()), + ..Default::default() + }; + + let route = docker_gateway_route( + &info, + IpAddr::V4(Ipv4Addr::new(172, 18, 0, 1)), + DEFAULT_SERVER_PORT, + ); + + assert_eq!( + route, + DockerGatewayRoute::Bridge { + bind_address: "172.18.0.1:8080".parse().unwrap(), + host_alias_ip: IpAddr::V4(Ipv4Addr::new(172, 18, 0, 1)), + } + ); + assert_eq!( + docker_extra_hosts(&route), + vec![ + "host.docker.internal:172.18.0.1".to_string(), + "host.openshell.internal:172.18.0.1".to_string() + ] + ); } #[test] @@ -123,6 +246,29 @@ fn build_environment_sets_docker_tls_paths() { ); } +#[test] +fn build_environment_keeps_path_driver_controlled() { + let mut sandbox = test_sandbox(); + let spec = sandbox.spec.as_mut().unwrap(); + spec.environment + .insert("PATH".to_string(), "/malicious/spec/bin".to_string()); + spec.template + .as_mut() + .unwrap() + .environment + .insert("PATH".to_string(), "/malicious/template/bin".to_string()); + + let env = build_environment(&sandbox, &runtime_config()); + let path_entries = env + .iter() + .filter(|entry| entry.starts_with("PATH=")) + .collect::>(); + + let expected_path = format!("PATH={SUPERVISOR_PATH}"); + assert_eq!(path_entries.len(), 1); + assert_eq!(path_entries[0], &expected_path); +} + #[test] fn build_mounts_uses_docker_tls_directory() { let mounts = build_mounts(&runtime_config()); @@ -168,13 +314,33 @@ fn build_container_create_body_clears_inherited_cmd() { .and_then(|labels| labels.get(SANDBOX_NAMESPACE_LABEL_KEY)), Some(&"default".to_string()) ); + let host_config = create_body.host_config.as_ref().unwrap(); assert!( 
+ host_config.device_requests.as_ref().is_none(), + "non-GPU containers should not request Docker devices" + ); + assert_eq!( + host_config.security_opt.as_ref(), + Some(&vec!["apparmor=unconfined".to_string()]) + ); + assert_eq!( + host_config.network_mode.as_deref(), + Some(DEFAULT_DOCKER_NETWORK_NAME) + ); + assert_eq!( + host_config.extra_hosts.as_ref(), + Some(&vec![ + "host.docker.internal:172.18.0.1".to_string(), + "host.openshell.internal:172.18.0.1".to_string() + ]) + ); + assert_eq!( create_body - .host_config + .networking_config .as_ref() - .and_then(|host_config| host_config.device_requests.as_ref()) - .is_none(), - "non-GPU containers should not request Docker devices" + .and_then(|config| config.endpoints_config.as_ref()) + .and_then(|endpoints| endpoints.get(DEFAULT_DOCKER_NETWORK_NAME)), + Some(&EndpointSettings::default()) ); } @@ -229,19 +395,22 @@ fn require_sandbox_identifier_rejects_when_id_and_name_are_empty() { } #[test] -fn build_container_create_body_uses_host_network() { +fn build_container_create_body_uses_bridge_network() { let create_body = build_container_create_body(&test_sandbox(), &runtime_config()).unwrap(); let host_config = create_body.host_config.expect("host_config is populated"); assert_eq!( host_config.network_mode, - Some("host".to_string()), - "sandbox must use host networking so 127.0.0.1 reaches the host gateway" + Some(DEFAULT_DOCKER_NETWORK_NAME.to_string()), + "sandbox should join the driver-managed bridge network" ); assert_eq!( host_config.extra_hosts, - Some(vec!["host.openshell.internal:127.0.0.1".to_string()]), - "sandbox should expose a stable host alias without host /etc/hosts edits" + Some(vec![ + "host.docker.internal:172.18.0.1".to_string(), + "host.openshell.internal:172.18.0.1".to_string() + ]), + "sandbox should expose stable host aliases for gateway callbacks" ); } diff --git a/crates/openshell-sandbox/src/sandbox/linux/netns.rs b/crates/openshell-sandbox/src/sandbox/linux/netns.rs index 
e926335e0..fad98b016 100644 --- a/crates/openshell-sandbox/src/sandbox/linux/netns.rs +++ b/crates/openshell-sandbox/src/sandbox/linux/netns.rs @@ -10,6 +10,7 @@ use miette::{IntoDiagnostic, Result}; use std::net::IpAddr; use std::os::unix::io::RawFd; +use std::path::Path; use std::process::Command; use tracing::{debug, warn}; use uuid::Uuid; @@ -18,6 +19,13 @@ use uuid::Uuid; const SUBNET_PREFIX: &str = "10.200.0"; const HOST_IP_SUFFIX: u8 = 1; const SANDBOX_IP_SUFFIX: u8 = 2; +const IP_SEARCH_PATHS: &[&str] = &["/usr/sbin/ip", "/sbin/ip", "/usr/bin/ip", "/bin/ip"]; +const NSENTER_SEARCH_PATHS: &[&str] = &[ + "/usr/bin/nsenter", + "/bin/nsenter", + "/usr/sbin/nsenter", + "/sbin/nsenter", +]; /// Handle to a network namespace with veth pair. /// @@ -661,14 +669,19 @@ impl Drop for NetworkNamespace { /// Run an `ip` command on the host. fn run_ip(args: &[&str]) -> Result<()> { - debug!(command = %format!("ip {}", args.join(" ")), "Running ip command"); + let ip_path = find_trusted_binary("ip", IP_SEARCH_PATHS)?; - let output = Command::new("ip").args(args).output().into_diagnostic()?; + debug!(command = %format!("{ip_path} {}", args.join(" ")), "Running ip command"); + + let output = Command::new(ip_path) + .args(args) + .output() + .into_diagnostic()?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); return Err(miette::miette!( - "ip {} failed: {}", + "{ip_path} {} failed: {}", args.join(" "), stderr.trim() )); @@ -688,15 +701,20 @@ fn run_ip(args: &[&str]) -> Result<()> { /// The supervisor's operations (addr add, link set, route add) are all /// netlink-based and do not need sysfs access. 
fn run_ip_netns(netns: &str, args: &[&str]) -> Result<()> { + let ip_path = find_trusted_binary("ip", IP_SEARCH_PATHS)?; + let nsenter_path = find_trusted_binary("nsenter", NSENTER_SEARCH_PATHS)?; let ns_path = format!("/var/run/netns/{netns}"); let net_flag = format!("--net={ns_path}"); - let mut full_args = vec![net_flag.as_str(), "--", "ip"]; + let mut full_args = vec![net_flag.as_str(), "--", ip_path]; full_args.extend(args); - debug!(command = %format!("nsenter {}", full_args.join(" ")), "Running ip in namespace via nsenter"); + debug!( + command = %format!("{nsenter_path} {}", full_args.join(" ")), + "Running ip in namespace via nsenter" + ); - let output = Command::new("nsenter") + let output = Command::new(nsenter_path) .args(&full_args) .output() .into_diagnostic()?; @@ -704,7 +722,7 @@ fn run_ip_netns(netns: &str, args: &[&str]) -> Result<()> { if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); return Err(miette::miette!( - "nsenter --net={} ip {} failed: {}", + "{nsenter_path} --net={} {ip_path} {} failed: {}", ns_path, args.join(" "), stderr.trim() @@ -719,6 +737,7 @@ fn run_ip_netns(netns: &str, args: &[&str]) -> Result<()> { /// Uses `nsenter` instead of `ip netns exec` to avoid the sysfs remount /// that fails in rootless container runtimes. See `run_ip_netns` for details. 
fn run_iptables_netns(netns: &str, iptables_cmd: &str, args: &[&str]) -> Result<()> { + let nsenter_path = find_trusted_binary("nsenter", NSENTER_SEARCH_PATHS)?; let ns_path = format!("/var/run/netns/{netns}"); let net_flag = format!("--net={ns_path}"); @@ -726,11 +745,11 @@ fn run_iptables_netns(netns: &str, iptables_cmd: &str, args: &[&str]) -> Result< full_args.extend(args); debug!( - command = %format!("nsenter {}", full_args.join(" ")), + command = %format!("{nsenter_path} {}", full_args.join(" ")), "Running iptables in namespace via nsenter" ); - let output = Command::new("nsenter") + let output = Command::new(nsenter_path) .args(&full_args) .output() .into_diagnostic()?; @@ -738,7 +757,7 @@ fn run_iptables_netns(netns: &str, iptables_cmd: &str, args: &[&str]) -> Result< if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); return Err(miette::miette!( - "nsenter --net={} {} failed: {}", + "{nsenter_path} --net={} {} failed: {}", ns_path, iptables_cmd, stderr.trim() @@ -754,6 +773,22 @@ fn run_iptables_netns(netns: &str, iptables_cmd: &str, args: &[&str]) -> Result< const IPTABLES_SEARCH_PATHS: &[&str] = &["/usr/sbin/iptables", "/sbin/iptables", "/usr/bin/iptables"]; +fn find_trusted_binary<'a>(name: &str, paths: &'a [&str]) -> Result<&'a str> { + paths + .iter() + .copied() + .find(|path| { + let path = Path::new(path); + path.is_absolute() && path.is_file() + }) + .ok_or_else(|| { + miette::miette!( + "trusted {name} helper not found; checked {}", + paths.join(", ") + ) + }) +} + /// Returns true if xt extension modules (e.g. `xt_comment`) cannot be used /// via the given iptables binary. 
/// @@ -823,12 +858,12 @@ fn xt_extensions_unavailable(iptables_path: &str) -> bool { fn find_iptables() -> Option { let standard_path = IPTABLES_SEARCH_PATHS .iter() - .find(|path| std::path::Path::new(path).exists()) + .find(|path| Path::new(path).exists()) .copied()?; if xt_extensions_unavailable(standard_path) { let legacy_path = standard_path.replace("iptables", "iptables-legacy"); - if std::path::Path::new(&legacy_path).exists() { + if Path::new(&legacy_path).exists() { debug!( legacy = legacy_path, "xt extensions unavailable; using iptables-legacy" @@ -843,7 +878,7 @@ fn find_iptables() -> Option { /// Find the ip6tables binary path, deriving it from the iptables location. fn find_ip6tables(iptables_path: &str) -> Option { let ip6_path = iptables_path.replace("iptables", "ip6tables"); - if std::path::Path::new(&ip6_path).exists() { + if Path::new(&ip6_path).exists() { Some(ip6_path) } else { None @@ -853,10 +888,32 @@ fn find_ip6tables(iptables_path: &str) -> Option { #[cfg(test)] mod tests { use super::*; + use std::fs; // These tests require root and network namespace support // Run with: sudo cargo test -- --ignored + #[test] + fn find_trusted_binary_uses_absolute_existing_file() { + let tempdir = tempfile::tempdir().unwrap(); + let helper = tempdir.path().join("ip"); + fs::write(&helper, b"test helper").unwrap(); + let helper = helper.to_str().unwrap(); + + assert_eq!( + find_trusted_binary("ip", &["relative-ip", "/missing/ip", helper]).unwrap(), + helper + ); + } + + #[test] + fn find_trusted_binary_rejects_missing_helpers() { + let err = + find_trusted_binary("nsenter", &["relative-nsenter", "/missing/nsenter"]).unwrap_err(); + + assert!(err.to_string().contains("trusted nsenter helper not found")); + } + #[test] #[ignore = "requires root privileges"] fn test_create_and_drop_namespace() { diff --git a/crates/openshell-server/src/cli.rs b/crates/openshell-server/src/cli.rs index ae90c8b34..e1830b70a 100644 --- a/crates/openshell-server/src/cli.rs +++ 
b/crates/openshell-server/src/cli.rs @@ -7,9 +7,10 @@ use clap::{Command, CommandFactory, FromArgMatches, Parser}; use miette::{IntoDiagnostic, Result}; use openshell_core::ComputeDriverKind; use openshell_core::config::{ - DEFAULT_SERVER_PORT, DEFAULT_SSH_HANDSHAKE_SKEW_SECS, DEFAULT_SSH_PORT, + DEFAULT_DOCKER_NETWORK_NAME, DEFAULT_SERVER_PORT, DEFAULT_SSH_HANDSHAKE_SKEW_SECS, + DEFAULT_SSH_PORT, }; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +use std::net::{IpAddr, SocketAddr}; use std::path::PathBuf; use tracing::info; use tracing_subscriber::EnvFilter; @@ -22,18 +23,14 @@ use crate::{run_server, tracing_bus::TracingLogBus}; #[command(version = openshell_core::VERSION)] #[command(about = "OpenShell gRPC/HTTP server", long_about = None)] struct Args { + /// IP address to bind the server, health, and metrics listeners to. + #[arg(long, default_value = "127.0.0.1", env = "OPENSHELL_BIND_ADDRESS")] + bind_address: IpAddr, + /// Port to bind the server to. #[arg(long, default_value_t = DEFAULT_SERVER_PORT, env = "OPENSHELL_SERVER_PORT")] port: u16, - /// Address to bind the server to. - #[arg( - long, - default_value_t = IpAddr::V4(Ipv4Addr::UNSPECIFIED), - env = "OPENSHELL_BIND_ADDRESS" - )] - bind_address: IpAddr, - /// Port for unauthenticated health endpoints (healthz, readyz). /// Set to 0 to disable the dedicated health listener. #[arg(long, default_value_t = 0, env = "OPENSHELL_HEALTH_PORT")] @@ -215,6 +212,14 @@ struct Args { #[arg(long, env = "OPENSHELL_DOCKER_TLS_KEY")] docker_tls_key: Option, + /// Docker bridge network used for sandbox containers. + #[arg( + long, + env = "OPENSHELL_DOCKER_NETWORK_NAME", + default_value = DEFAULT_DOCKER_NETWORK_NAME + )] + docker_network_name: String, + /// Disable TLS entirely — listen on plaintext HTTP. /// Use this when the gateway sits behind a reverse proxy or tunnel /// (e.g. Cloudflare Tunnel) that terminates TLS at the edge. 
@@ -295,7 +300,7 @@ async fn run_from_args(args: Args) -> Result<()> { EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&args.log_level)), ); - let bind = SocketAddr::from((args.bind_address, args.port)); + let bind = SocketAddr::new(args.bind_address, args.port); let tls = if args.disable_tls { None @@ -332,7 +337,7 @@ async fn run_from_args(args: Args) -> Result<()> { args.port )); } - let health_bind = SocketAddr::from(([0, 0, 0, 0], args.health_port)); + let health_bind = SocketAddr::new(args.bind_address, args.health_port); config = config.with_health_bind_address(health_bind); } @@ -349,7 +354,7 @@ async fn run_from_args(args: Args) -> Result<()> { args.health_port )); } - let metrics_bind = SocketAddr::from(([0, 0, 0, 0], args.metrics_port)); + let metrics_bind = SocketAddr::new(args.bind_address, args.metrics_port); config = config.with_metrics_bind_address(metrics_bind); } @@ -416,6 +421,7 @@ async fn run_from_args(args: Args) -> Result<()> { guest_tls_ca: args.docker_tls_ca, guest_tls_cert: args.docker_tls_cert, guest_tls_key: args.docker_tls_key, + network_name: args.docker_network_name, }; if args.disable_tls { @@ -457,10 +463,10 @@ mod tests { } #[test] - fn command_defaults_bind_address_to_all_interfaces() { + fn command_defaults_bind_address_to_loopback() { let args = Args::try_parse_from(["openshell-gateway", "--db-url", "sqlite::memory:"]).unwrap(); - assert_eq!(args.bind_address, IpAddr::V4(Ipv4Addr::UNSPECIFIED)); + assert_eq!(args.bind_address, IpAddr::V4(Ipv4Addr::LOCALHOST)); } #[test] diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index df6f88c77..2d6351637 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -36,6 +36,7 @@ use openshell_driver_podman::{ }; use prost::Message; use std::fmt; +use std::net::SocketAddr; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; @@ -228,6 +229,7 @@ pub struct 
ComputeRuntime { tracing_log_bus: TracingLogBus, supervisor_sessions: Arc, sync_lock: Arc>, + gateway_bind_addresses: Vec, } impl fmt::Debug for ComputeRuntime { @@ -249,6 +251,7 @@ impl ComputeRuntime { tracing_log_bus: TracingLogBus, supervisor_sessions: Arc, _allows_loopback_endpoints: bool, + gateway_bind_addresses: Vec, ) -> Result { let default_image = driver .get_capabilities(Request::new(GetCapabilitiesRequest {})) @@ -268,6 +271,7 @@ impl ComputeRuntime { tracing_log_bus, supervisor_sessions, sync_lock: Arc::new(Mutex::new(())), + gateway_bind_addresses, }) } @@ -285,6 +289,7 @@ impl ComputeRuntime { .await .map_err(|err| ComputeError::Message(err.to_string()))?, ); + let gateway_bind_addresses = driver.gateway_bind_addresses(); let shutdown_cleanup: Arc = driver.clone(); let startup_resume: Arc = driver.clone(); let driver: SharedComputeDriver = driver; @@ -299,6 +304,7 @@ impl ComputeRuntime { tracing_log_bus, supervisor_sessions, true, + gateway_bind_addresses, ) .await } @@ -326,6 +332,7 @@ impl ComputeRuntime { tracing_log_bus, supervisor_sessions, false, + Vec::new(), ) .await } @@ -351,6 +358,7 @@ impl ComputeRuntime { tracing_log_bus, supervisor_sessions, true, + Vec::new(), ) .await } @@ -378,6 +386,7 @@ impl ComputeRuntime { tracing_log_bus, supervisor_sessions, true, + Vec::new(), ) .await } @@ -387,6 +396,11 @@ impl ComputeRuntime { &self.default_image } + #[must_use] + pub fn gateway_bind_addresses(&self) -> &[SocketAddr] { + &self.gateway_bind_addresses + } + pub async fn validate_sandbox_create(&self, sandbox: &Sandbox) -> Result<(), Status> { let driver_sandbox = driver_sandbox_from_public(sandbox); self.driver @@ -1604,6 +1618,7 @@ pub async fn new_test_runtime(store: Arc) -> ComputeRuntime { tracing_log_bus: TracingLogBus::new(), supervisor_sessions: Arc::new(SupervisorSessionRegistry::new()), sync_lock: Arc::new(Mutex::new(())), + gateway_bind_addresses: Vec::new(), } } @@ -1770,6 +1785,7 @@ mod tests { tracing_log_bus: 
TracingLogBus::new(), supervisor_sessions: Arc::new(SupervisorSessionRegistry::new()), sync_lock: Arc::new(Mutex::new(())), + gateway_bind_addresses: Vec::new(), } } diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index f37ba472f..155643760 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -43,8 +43,8 @@ use std::io::ErrorKind; use std::net::SocketAddr; use std::sync::{Arc, Mutex}; use std::time::Duration; -use tokio::net::TcpListener; -use tokio::sync::broadcast; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::watch; use tracing::{debug, error, info, warn}; use compute::{ComputeRuntime, DockerComputeConfig, VmComputeConfig}; @@ -216,23 +216,18 @@ pub async fn run_server( // Create the multiplexed service let service = MultiplexService::new(state.clone()); - // Bind the primary TCP listener plus any extras requested by drivers. - // The same multiplex service is served on each address so the CLI on - // loopback and sandboxes on a driver-supplied interface can both reach - // the gateway with identical semantics. 
- let mut listeners: Vec<(SocketAddr, TcpListener)> = Vec::new(); - let primary_listener = TcpListener::bind(config.bind_address) - .await - .map_err(|e| Error::transport(format!("failed to bind to {}: {e}", config.bind_address)))?; - info!(address = %config.bind_address, "Server listening"); - listeners.push((config.bind_address, primary_listener)); - - for extra in &config.extra_bind_addresses { - let extra_listener = TcpListener::bind(*extra) + let mut extra_listener_addresses = config.extra_bind_addresses.clone(); + extra_listener_addresses.extend_from_slice(state.compute.gateway_bind_addresses()); + let gateway_listener_addresses = + gateway_listener_addresses(config.bind_address, &extra_listener_addresses); + let mut gateway_listeners = Vec::with_capacity(gateway_listener_addresses.len()); + for address in gateway_listener_addresses { + let listener = TcpListener::bind(address) .await - .map_err(|e| Error::transport(format!("failed to bind extra address {extra}: {e}")))?; - info!(address = %extra, "Server listening on extra address"); - listeners.push((*extra, extra_listener)); + .map_err(|e| Error::transport(format!("failed to bind to {address}: {e}")))?; + let local_addr = listener.local_addr().unwrap_or(address); + info!(address = %local_addr, "Server listening"); + gateway_listeners.push((listener, local_addr)); } // Bind the unauthenticated health endpoint on a separate port when configured. @@ -291,24 +286,27 @@ pub async fn run_server( None }; - // Coordinate graceful shutdown across every listener: a single broadcast - // channel notifies all accept loops, and a `JoinSet` lets us wait for - // them to drain before returning. 
- let (shutdown_tx, _) = broadcast::channel::<()>(1); - let mut accept_tasks = tokio::task::JoinSet::new(); - for (addr, listener) in listeners { - let service = service.clone(); - let tls_acceptor = tls_acceptor.clone(); - let mut shutdown_rx = shutdown_tx.subscribe(); - accept_tasks.spawn(async move { - run_accept_loop(addr, listener, service, tls_acceptor, &mut shutdown_rx).await; - }); + let (shutdown_tx, shutdown_rx) = watch::channel(false); + let mut listener_tasks = Vec::with_capacity(gateway_listeners.len()); + for (listener, listen_addr) in gateway_listeners { + listener_tasks.push(tokio::spawn(serve_gateway_listener( + listener, + listen_addr, + service.clone(), + tls_acceptor.clone(), + shutdown_rx.clone(), + ))); } shutdown_signal().await; info!("Shutdown signal received; stopping gateway"); - let _ = shutdown_tx.send(()); - while accept_tasks.join_next().await.is_some() {} + let _ = shutdown_tx.send(true); + + for task in listener_tasks { + if let Err(err) = task.await { + warn!(error = %err, "Gateway listener task failed during shutdown"); + } + } state .compute @@ -319,64 +317,96 @@ pub async fn run_server( Ok(()) } -/// Drive a single listener until either the listener errors fatally or the -/// gateway receives a shutdown signal. -/// -/// All listeners share the same `MultiplexService` and (optional) TLS -/// acceptor, so callers can run multiple instances of this loop in parallel -/// to expose the gateway on more than one bind address without forking the -/// service definition. 
-async fn run_accept_loop( - bind_addr: SocketAddr, +fn gateway_listener_addresses( + bind_address: SocketAddr, + extra_addresses: &[SocketAddr], +) -> Vec { + let mut addresses = vec![bind_address]; + for address in extra_addresses { + if !addresses + .iter() + .any(|existing| listener_covers(*existing, *address)) + { + addresses.push(*address); + } + } + addresses +} + +fn listener_covers(existing: SocketAddr, requested: SocketAddr) -> bool { + if existing == requested { + return true; + } + if existing.port() != requested.port() { + return false; + } + + match (existing.ip(), requested.ip()) { + (std::net::IpAddr::V4(existing), std::net::IpAddr::V4(_)) => existing.is_unspecified(), + (std::net::IpAddr::V6(existing), std::net::IpAddr::V6(_)) => existing.is_unspecified(), + _ => false, + } +} + +async fn serve_gateway_listener( listener: TcpListener, + listen_addr: SocketAddr, service: MultiplexService, tls_acceptor: Option, - shutdown_rx: &mut broadcast::Receiver<()>, + mut shutdown: watch::Receiver, ) { loop { - let (stream, addr) = tokio::select! { - _ = shutdown_rx.recv() => { - debug!(bind = %bind_addr, "Listener received shutdown"); - return; - } - accepted = listener.accept() => { - match accepted { - Ok(conn) => conn, - Err(e) => { - error!(error = %e, bind = %bind_addr, "Failed to accept connection"); - continue; - } + let accepted = tokio::select! 
{ + changed = shutdown.changed() => { + if changed.is_err() || *shutdown.borrow() { + break; } + continue; } + accepted = listener.accept() => accepted, }; - let service = service.clone(); + let (stream, addr) = match accepted { + Ok(conn) => conn, + Err(e) => { + error!(error = %e, listen = %listen_addr, "Failed to accept connection"); + continue; + } + }; - if let Some(ref acceptor) = tls_acceptor { - let tls_acceptor = acceptor.clone(); - tokio::spawn(async move { - match tls_acceptor.inner().accept(stream).await { - Ok(tls_stream) => { - if let Err(e) = service.serve(tls_stream).await { - error!(error = %e, client = %addr, "Connection error"); - } - } - Err(e) => { - if is_benign_tls_handshake_failure(&e) { - debug!(error = %e, client = %addr, "TLS handshake closed early"); - } else { - error!(error = %e, client = %addr, "TLS handshake failed"); - } + spawn_gateway_connection(stream, addr, service.clone(), tls_acceptor.clone()); + } +} + +fn spawn_gateway_connection( + stream: TcpStream, + addr: SocketAddr, + service: MultiplexService, + tls_acceptor: Option, +) { + if let Some(acceptor) = tls_acceptor { + tokio::spawn(async move { + match acceptor.inner().accept(stream).await { + Ok(tls_stream) => { + if let Err(e) = service.serve(tls_stream).await { + error!(error = %e, client = %addr, "Connection error"); } } - }); - } else { - tokio::spawn(async move { - if let Err(e) = service.serve(stream).await { - error!(error = %e, client = %addr, "Connection error"); + Err(e) => { + if is_benign_tls_handshake_failure(&e) { + debug!(error = %e, client = %addr, "TLS handshake closed early"); + } else { + error!(error = %e, client = %addr, "TLS handshake failed"); + } } - }); - } + } + }); + } else { + tokio::spawn(async move { + if let Err(e) = service.serve(stream).await { + error!(error = %e, client = %addr, "Connection error"); + } + }); } } @@ -560,9 +590,12 @@ fn configured_compute_driver(config: &Config) -> Result { #[cfg(test)] mod tests { - use 
super::{configured_compute_driver, is_benign_tls_handshake_failure}; + use super::{ + configured_compute_driver, gateway_listener_addresses, is_benign_tls_handshake_failure, + }; use openshell_core::{ComputeDriverKind, Config}; use std::io::{Error, ErrorKind}; + use std::net::SocketAddr; #[test] fn classifies_probe_style_tls_disconnects_as_benign() { @@ -652,4 +685,26 @@ mod tests { ComputeDriverKind::Docker ); } + + #[test] + fn gateway_listener_addresses_skip_driver_address_covered_by_wildcard() { + let primary: SocketAddr = "0.0.0.0:8080".parse().unwrap(); + let docker: SocketAddr = "172.18.0.1:8080".parse().unwrap(); + + assert_eq!( + gateway_listener_addresses(primary, &[docker, docker]), + vec![primary] + ); + } + + #[test] + fn gateway_listener_addresses_include_driver_address_on_distinct_ip() { + let primary: SocketAddr = "127.0.0.1:8080".parse().unwrap(); + let docker: SocketAddr = "172.18.0.1:8080".parse().unwrap(); + + assert_eq!( + gateway_listener_addresses(primary, &[docker, docker]), + vec![primary, docker] + ); + } } diff --git a/deploy/docker/Dockerfile.images b/deploy/docker/Dockerfile.images index ebe5e267e..16fe08ecb 100644 --- a/deploy/docker/Dockerfile.images +++ b/deploy/docker/Dockerfile.images @@ -66,7 +66,7 @@ USER openshell EXPOSE 8080 ENTRYPOINT ["openshell-gateway"] -CMD ["--port", "8080"] +CMD ["--bind-address", "0.0.0.0", "--port", "8080"] # --------------------------------------------------------------------------- # Final supervisor image diff --git a/deploy/helm/openshell/templates/statefulset.yaml b/deploy/helm/openshell/templates/statefulset.yaml index 86f6dc3ed..b87eebad6 100644 --- a/deploy/helm/openshell/templates/statefulset.yaml +++ b/deploy/helm/openshell/templates/statefulset.yaml @@ -47,6 +47,8 @@ spec: image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}" imagePullPolicy: {{ .Values.image.pullPolicy }} args: + - --bind-address + - "0.0.0.0" - --port - {{ .Values.service.port | quote 
}} - --health-port