diff --git a/.agents/skills/openshell-cli/SKILL.md b/.agents/skills/openshell-cli/SKILL.md index 132c99686..160e9f360 100644 --- a/.agents/skills/openshell-cli/SKILL.md +++ b/.agents/skills/openshell-cli/SKILL.md @@ -421,10 +421,14 @@ Watch for `deny` actions that indicate the user's work is being blocked by polic When denied actions are observed: -1. Pull current policy: `openshell policy get work-session --full > policy.yaml` -2. Modify the policy to allow the blocked actions (use `generate-sandbox-policy` skill for content) -3. Push the update: `openshell policy set work-session --policy policy.yaml --wait` -4. Verify: `openshell policy list work-session` +1. Prefer incremental updates for additive network changes: + `openshell policy update work-session --add-endpoint api.github.com:443:read-only:rest:enforce --binary /usr/bin/gh --wait` + `openshell policy update work-session --add-allow api.github.com:443:POST:/repos/*/issues --wait` +2. Use full YAML replacement when the change is broad or touches non-network fields: + `openshell policy get work-session --full > policy.yaml` + Modify the policy to allow the blocked actions (use `generate-sandbox-policy` skill for content) + `openshell policy set work-session --policy policy.yaml --wait` +3. Verify: `openshell policy list work-session` The user does not need to disconnect -- policy updates are hot-reloaded within ~30 seconds (or immediately when using `--wait`, which polls for confirmation). 
@@ -543,6 +547,7 @@ $ openshell sandbox upload --help | Create with custom policy | `openshell sandbox create --policy ./p.yaml` | | Connect to sandbox | `openshell sandbox connect <name>` | | Stream live logs | `openshell logs <name> --tail` | +| Incremental policy update | `openshell policy update <name> --add-endpoint host:443:read-only:rest:enforce --binary /usr/bin/curl --wait` | | Pull current policy | `openshell policy get <name> --full > p.yaml` | | Push updated policy | `openshell policy set <name> --policy p.yaml --wait` | | Policy revision history | `openshell policy list <name>` | diff --git a/.agents/skills/openshell-cli/cli-reference.md b/.agents/skills/openshell-cli/cli-reference.md index e344f20df..c9c5450a0 100644 --- a/.agents/skills/openshell-cli/cli-reference.md +++ b/.agents/skills/openshell-cli/cli-reference.md @@ -268,9 +268,32 @@ View sandbox logs. Supports one-shot and streaming. ## Policy Commands +### `openshell policy update <name>` + +Incrementally merge live network policy changes into the current sandbox policy. Multiple flags in one invocation are applied as one atomic batch and create at most one new revision. + +| Flag | Default | Description | +|------|---------|-------------| +| `--add-endpoint <spec>` | repeatable | `host:port[:access[:protocol[:enforcement]]]`. Adds or merges an endpoint. `access`: `read-only`, `read-write`, `full`. `protocol`: `rest`, `sql`. `enforcement`: `enforce`, `audit`. | +| `--remove-endpoint <spec>` | repeatable | `host:port`. Removes the endpoint or just the requested port from a multi-port endpoint. | +| `--add-allow <spec>` | repeatable | `host:port:METHOD:path_glob`. Adds REST allow rules to an existing `protocol: rest` endpoint. | +| `--add-deny <spec>` | repeatable | `host:port:METHOD:path_glob`. Adds REST deny rules to an existing `protocol: rest` endpoint that already has an allow base. | +| `--remove-rule <name>` | repeatable | Deletes a named network rule. | +| `--binary <path>` | repeatable | Adds binaries to each `--add-endpoint` rule. Valid only with `--add-endpoint`. 
| +| `--rule-name <name>` | none | Overrides the generated rule name. Valid only when exactly one `--add-endpoint` is provided. | +| `--dry-run` | false | Preview the merged policy locally without sending an update to the gateway. | +| `--wait` | false | Wait for the sandbox to confirm the new policy revision is loaded. | +| `--timeout <seconds>` | 60 | Timeout for `--wait`. | + +Notes: + +- `--add-allow` and `--add-deny` currently operate only on `protocol: rest` endpoints. +- `--wait` cannot be combined with `--dry-run`. +- Use `policy set` when replacing the full policy or changing static sections. + ### `openshell policy set <name> --policy <file>` -Update the policy on a live sandbox. Only the dynamic `network_policies` field can be changed at runtime. +Replace the full policy on a live sandbox. Only the dynamic `network_policies` field can be changed at runtime. | Flag | Default | Description | |------|---------|-------------| diff --git a/architecture/security-policy.md b/architecture/security-policy.md index b8bda8f91..8721e8a7b 100644 --- a/architecture/security-policy.md +++ b/architecture/security-policy.md @@ -162,6 +162,22 @@ This guarantees that the same logical policy always produces the same hash regar **Idempotent updates**: `UpdateSandboxPolicy` compares the deterministic hash of the submitted policy against the latest stored revision's hash. If they match, the handler returns the existing version and hash without creating a new revision. The CLI detects this (the returned version equals the pre-call version) and prints `Policy unchanged` instead of `Policy version N submitted`. This makes repeated `policy set` calls safe and idempotent. +### Incremental Merge Updates + +`UpdateConfigRequest.merge_operations` supports batched incremental changes to the dynamic `network_policies` section. The CLI exposes this as `openshell policy update`. 
+ +Supported first-pass operations: + +- `--add-endpoint host:port[:access[:protocol[:enforcement]]]` +- `--remove-endpoint host:port` +- `--remove-rule <name>` +- `--add-allow host:port:METHOD:path_glob` +- `--add-deny host:port:METHOD:path_glob` + +`--add-allow` and `--add-deny` target existing `protocol: rest` endpoints only. `--binary` may be repeated with `--add-endpoint`, and `--rule-name` is allowed only when exactly one `--add-endpoint` is present. + +Each `openshell policy update` invocation is atomic at the revision level: the CLI sends one `merge_operations` batch, the server merges the whole batch into the latest policy, validates the result, and persists at most one new revision. Concurrency is handled with optimistic retries on the `(sandbox_id, version)` uniqueness boundary. If another writer wins first, the server refetches the latest policy, reapplies the full batch, revalidates it, and retries. This preserves batch atomicity without serializing all sandbox policy writes behind a sandbox-global mutex. + ### Policy Revision Statuses | Status | Meaning | @@ -206,9 +222,20 @@ Failure scenarios that trigger LKG behavior include: ### CLI Commands -The `openshell policy` subcommand group manages live policy updates: +The `openshell policy` subcommand group manages live policy updates through full replacement (`policy set`) and incremental merges (`policy update`): ```bash +# Merge endpoint/rule changes into the current sandbox policy +openshell policy update \ + --add-endpoint api.github.com:443:read-only:rest:enforce \ + --binary /usr/bin/gh \ + --wait + +# Add a REST allow rule to an existing endpoint +openshell policy update \ + --add-allow api.github.com:443:POST:/repos/*/issues \ + --wait + # Push a new policy to a running sandbox openshell policy set --policy updated-policy.yaml @@ -255,6 +282,7 @@ Both `set` and `delete` require interactive confirmation (or `--yes` to bypass). 
When a global policy is active, sandbox-scoped policy mutations are blocked: - `policy set <name>` returns `FailedPrecondition: "policy is managed globally"` +- `policy update <name>` returns `FailedPrecondition: "policy is managed globally"` - `rule approve`, `rule approve-all` return `FailedPrecondition: "cannot approve rules while a global policy is active"` - Revoking a previously approved draft chunk is blocked (it would modify the sandbox policy) - Rejecting pending chunks is allowed (does not modify the sandbox policy) @@ -270,7 +298,7 @@ See [Gateway Settings Channel](gateway-settings.md#global-policy-lifecycle) for When `--full` is specified, the server includes the deserialized `SandboxPolicy` protobuf in the `SandboxPolicyRevision.policy` field (see `crates/openshell-server/src/grpc.rs` -- `policy_record_to_revision()` with `include_policy: true`). The CLI converts this proto back to YAML via `policy_to_yaml()`, which uses a `BTreeMap` for `network_policies` to produce deterministic key ordering. See `crates/openshell-cli/src/run.rs` -- `policy_to_yaml()`, `policy_get()`. -See `crates/openshell-cli/src/main.rs` -- `PolicyCommands` enum, `crates/openshell-cli/src/run.rs` -- `policy_set()`, `policy_get()`, `policy_list()`. +See `crates/openshell-cli/src/main.rs` -- `PolicyCommands` enum, `crates/openshell-cli/src/run.rs` -- `policy_update()`, `policy_set()`, `policy_get()`, `policy_list()`. 
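The optimistic-retry merge described under "Incremental Merge Updates" can be sketched in miniature. `Revision`, `PolicyStore`, and `apply_batch` below are illustrative stand-ins, not the real server types, and endpoints are reduced to plain strings:

```rust
// Hypothetical sketch (assumed names, not the real crate API) of the
// batch-atomic, optimistic-retry merge: one batch produces at most one
// new revision, and a lost race reapplies the whole batch.

#[derive(Clone, Debug, PartialEq)]
struct Revision {
    version: u64,
    endpoints: Vec<String>,
}

struct PolicyStore {
    latest: Revision,
}

impl PolicyStore {
    // Models the `(sandbox_id, version)` uniqueness boundary: the insert
    // succeeds only if no other writer bumped the version in between.
    fn try_insert(&mut self, expected_parent: u64, candidate: Revision) -> bool {
        if self.latest.version == expected_parent {
            self.latest = candidate;
            true
        } else {
            false
        }
    }
}

// Applies the whole batch atomically: on a lost race, refetch the latest
// policy and reapply the *entire* batch, never individual operations.
fn apply_batch(
    store: &mut PolicyStore,
    batch: &[String],
    concurrent_writer: &mut Option<Revision>,
) -> Revision {
    loop {
        let parent = store.latest.clone();
        let mut merged = parent.endpoints.clone();
        merged.extend(batch.iter().cloned());
        // Simulate another writer committing first (at most once here).
        if let Some(other) = concurrent_writer.take() {
            store.latest = other;
        }
        let candidate = Revision {
            version: parent.version + 1,
            endpoints: merged,
        };
        if store.try_insert(parent.version, candidate.clone()) {
            return candidate;
        }
        // Lost the race: loop around, refetch, and retry the full batch.
    }
}
```

In the real handler, validation of the merged policy would run before the insert attempt; it is omitted here for brevity.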
--- diff --git a/crates/openshell-cli/src/lib.rs b/crates/openshell-cli/src/lib.rs index 09e05449b..1746547ef 100644 --- a/crates/openshell-cli/src/lib.rs +++ b/crates/openshell-cli/src/lib.rs @@ -12,6 +12,7 @@ pub mod auth; pub mod bootstrap; pub mod completers; pub mod edge_tunnel; +pub(crate) mod policy_update; pub mod run; pub mod ssh; pub mod tls; diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 292922411..8c83914e3 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -254,6 +254,8 @@ const POLICY_EXAMPLES: &str = "\x1b[1mALIAS\x1b[0m \x1b[1mEXAMPLES\x1b[0m $ openshell policy get my-sandbox $ openshell policy set my-sandbox --policy policy.yaml + $ openshell policy update my-sandbox --add-endpoint api.github.com:443:read-only:rest:enforce + $ openshell policy update my-sandbox --add-allow api.github.com:443:GET:/repos/** $ openshell policy set --global --policy policy.yaml $ openshell policy delete --global $ openshell policy list my-sandbox @@ -1438,6 +1440,54 @@ enum PolicyCommands { timeout: u64, }, + /// Incrementally update policy on a live sandbox. + #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + Update { + /// Sandbox name (defaults to last-used sandbox). + #[arg(add = ArgValueCompleter::new(completers::complete_sandbox_names))] + name: Option<String>, + + /// Add or merge an endpoint: host:port[:access[:protocol[:enforcement]]]. + #[arg(long = "add-endpoint")] + add_endpoints: Vec<String>, + + /// Remove an endpoint: host:port. + #[arg(long = "remove-endpoint")] + remove_endpoints: Vec<String>, + + /// Add a REST allow rule: host:port:METHOD:path_glob. + #[arg(long = "add-allow")] + add_allow: Vec<String>, + + /// Add a REST deny rule: host:port:METHOD:path_glob. + #[arg(long = "add-deny")] + add_deny: Vec<String>, + + /// Remove a network rule by name. + #[arg(long = "remove-rule")] + remove_rules: Vec<String>, + + /// Add binaries to each --add-endpoint rule. 
+ #[arg(long = "binary", value_hint = ValueHint::FilePath)] + binaries: Vec<String>, + + /// Override the generated rule name when exactly one --add-endpoint is provided. + #[arg(long = "rule-name")] + rule_name: Option<String>, + + /// Preview the merged policy without sending it to the gateway. + #[arg(long)] + dry_run: bool, + + /// Wait for the sandbox to load the policy revision. + #[arg(long)] + wait: bool, + + /// Timeout for --wait in seconds. + #[arg(long, default_value_t = 60)] + timeout: u64, + }, + /// Show current active policy for a sandbox or the global policy. #[command(help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] Get { @@ -1988,6 +2038,37 @@ async fn main() -> Result<()> { .await?; } } + PolicyCommands::Update { + name, + add_endpoints, + remove_endpoints, + add_allow, + add_deny, + remove_rules, + binaries, + rule_name, + dry_run, + wait, + timeout, + } => { + let name = resolve_sandbox_name(name, &ctx.name)?; + run::sandbox_policy_update( + &ctx.endpoint, + &name, + &add_endpoints, + &remove_endpoints, + &add_deny, + &add_allow, + &remove_rules, + &binaries, + rule_name.as_deref(), + dry_run, + wait, + timeout, + &tls, + ) + .await?; + } PolicyCommands::Get { name, rev, diff --git a/crates/openshell-cli/src/policy_update.rs new file mode 100644 index 000000000..9f053f73b --- /dev/null +++ b/crates/openshell-cli/src/policy_update.rs @@ -0,0 +1,473 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use std::collections::BTreeMap; + +use miette::{Result, miette}; +use openshell_core::proto::policy_merge_operation; +use openshell_core::proto::{ + AddAllowRules, AddDenyRules, AddNetworkRule, L7Allow, L7DenyRule, L7Rule, NetworkBinary, + NetworkEndpoint, NetworkPolicyRule, PolicyMergeOperation, RemoveNetworkEndpoint, + RemoveNetworkRule, +}; +use openshell_policy::{PolicyMergeOp, generated_rule_name}; + +#[derive(Debug, Clone)] +pub(crate) struct PolicyUpdatePlan { + pub merge_operations: Vec<PolicyMergeOperation>, + pub preview_operations: Vec<PolicyMergeOp>, +} + +pub(crate) fn build_policy_update_plan( + add_endpoints: &[String], + remove_endpoints: &[String], + add_deny: &[String], + add_allow: &[String], + remove_rules: &[String], + binaries: &[String], + rule_name: Option<&str>, +) -> Result<PolicyUpdatePlan> { + if binaries.iter().any(|binary| binary.trim().is_empty()) { + return Err(miette!("--binary values must not be empty")); + } + if !binaries.is_empty() && add_endpoints.is_empty() { + return Err(miette!("--binary can only be used with --add-endpoint")); + } + if rule_name.is_some() && add_endpoints.is_empty() { + return Err(miette!("--rule-name can only be used with --add-endpoint")); + } + if rule_name.is_some() && add_endpoints.len() > 1 { + return Err(miette!( + "--rule-name is only supported when exactly one --add-endpoint is provided" + )); + } + + let mut merge_operations = Vec::new(); + let mut preview_operations = Vec::new(); + + let deduped_binaries = dedup_strings(binaries); + for spec in add_endpoints { + let endpoint = parse_add_endpoint_spec(spec)?; + let target_rule_name = rule_name + .map(str::trim) + .filter(|name| !name.is_empty()) + .map(ToString::to_string) + .unwrap_or_else(|| generated_rule_name(&endpoint.host, endpoint.port)); + let rule = NetworkPolicyRule { + name: target_rule_name.clone(), + endpoints: vec![endpoint.clone()], + binaries: deduped_binaries + .iter() + .map(|path| NetworkBinary { + path: path.clone(), + ..Default::default() 
+ }) + .collect(), + }; + merge_operations.push(PolicyMergeOperation { + operation: Some(policy_merge_operation::Operation::AddRule(AddNetworkRule { + rule_name: target_rule_name.clone(), + rule: Some(rule.clone()), + })), + }); + preview_operations.push(PolicyMergeOp::AddRule { + rule_name: target_rule_name, + rule, + }); + } + + for spec in remove_endpoints { + let (host, port) = parse_remove_endpoint_spec(spec)?; + merge_operations.push(PolicyMergeOperation { + operation: Some(policy_merge_operation::Operation::RemoveEndpoint( + RemoveNetworkEndpoint { + rule_name: String::new(), + host: host.clone(), + port, + }, + )), + }); + preview_operations.push(PolicyMergeOp::RemoveEndpoint { + rule_name: None, + host, + port, + }); + } + + for name in remove_rules { + let rule_name = name.trim(); + if rule_name.is_empty() { + return Err(miette!("--remove-rule values must not be empty")); + } + merge_operations.push(PolicyMergeOperation { + operation: Some(policy_merge_operation::Operation::RemoveRule( + RemoveNetworkRule { + rule_name: rule_name.to_string(), + }, + )), + }); + preview_operations.push(PolicyMergeOp::RemoveRule { + rule_name: rule_name.to_string(), + }); + } + + for ((host, port), rules) in group_allow_rules(add_allow)? { + merge_operations.push(PolicyMergeOperation { + operation: Some(policy_merge_operation::Operation::AddAllowRules( + AddAllowRules { + host: host.clone(), + port, + rules: rules.clone(), + }, + )), + }); + preview_operations.push(PolicyMergeOp::AddAllowRules { host, port, rules }); + } + + for ((host, port), deny_rules) in group_deny_rules(add_deny)? 
{ + merge_operations.push(PolicyMergeOperation { + operation: Some(policy_merge_operation::Operation::AddDenyRules( + AddDenyRules { + host: host.clone(), + port, + deny_rules: deny_rules.clone(), + }, + )), + }); + preview_operations.push(PolicyMergeOp::AddDenyRules { + host, + port, + deny_rules, + }); + } + + if merge_operations.is_empty() { + return Err(miette!( + "policy update requires at least one operation flag" + )); + } + + Ok(PolicyUpdatePlan { + merge_operations, + preview_operations, + }) +} + +fn group_allow_rules(specs: &[String]) -> Result<BTreeMap<(String, u32), Vec<L7Rule>>> { + let mut grouped = BTreeMap::new(); + for spec in specs { + let parsed = parse_l7_rule_spec("--add-allow", spec)?; + grouped + .entry((parsed.host, parsed.port)) + .or_insert_with(Vec::new) + .push(L7Rule { + allow: Some(L7Allow { + method: parsed.method, + path: parsed.path, + command: String::new(), + query: Default::default(), + }), + }); + } + Ok(grouped) +} + +fn group_deny_rules(specs: &[String]) -> Result<BTreeMap<(String, u32), Vec<L7DenyRule>>> { + let mut grouped = BTreeMap::new(); + for spec in specs { + let parsed = parse_l7_rule_spec("--add-deny", spec)?; + grouped + .entry((parsed.host, parsed.port)) + .or_insert_with(Vec::new) + .push(L7DenyRule { + method: parsed.method, + path: parsed.path, + command: String::new(), + query: Default::default(), + }); + } + Ok(grouped) +} + +#[derive(Debug, Clone)] +struct ParsedL7RuleSpec { + host: String, + port: u32, + method: String, + path: String, +} + +fn parse_l7_rule_spec(flag: &str, spec: &str) -> Result<ParsedL7RuleSpec> { + let parts = spec.split(':').collect::<Vec<_>>(); + if parts.len() != 4 { + return Err(miette!( + "{flag} expects host:port:METHOD:path_glob, got '{spec}'" + )); + } + + let host = parse_host(flag, spec, parts[0])?; + let port = parse_port(flag, spec, parts[1])?; + let method = parts[2].trim(); + if method.is_empty() { + return Err(miette!("{flag} has an empty METHOD segment in '{spec}'")); + } + if method.contains(char::is_whitespace) { + return Err(miette!( + "{flag} METHOD must not contain 
whitespace in '{spec}'" + )); + } + + let path = parts[3].trim(); + if path.is_empty() { + return Err(miette!("{flag} has an empty path segment in '{spec}'")); + } + if !path.starts_with('/') && path != "**" && !path.starts_with("**/") { + return Err(miette!( + "{flag} path must start with '/' or be '**', got '{path}' in '{spec}'" + )); + } + + Ok(ParsedL7RuleSpec { + host, + port, + method: method.to_ascii_uppercase(), + path: path.to_string(), + }) +} + +fn parse_remove_endpoint_spec(spec: &str) -> Result<(String, u32)> { + let parts = spec.split(':').collect::<Vec<_>>(); + if parts.len() != 2 { + return Err(miette!("--remove-endpoint expects host:port, got '{spec}'")); + } + + Ok(( + parse_host("--remove-endpoint", spec, parts[0])?, + parse_port("--remove-endpoint", spec, parts[1])?, + )) +} + +fn parse_add_endpoint_spec(spec: &str) -> Result<NetworkEndpoint> { + let parts = spec.split(':').collect::<Vec<_>>(); + if !(2..=5).contains(&parts.len()) { + return Err(miette!( + "--add-endpoint expects host:port[:access[:protocol[:enforcement]]], got '{spec}'" + )); + } + + let host = parse_host("--add-endpoint", spec, parts[0])?; + let port = parse_port("--add-endpoint", spec, parts[1])?; + + let access = parts.get(2).copied().unwrap_or("").trim(); + let protocol = parts.get(3).copied().unwrap_or("").trim(); + let enforcement = parts.get(4).copied().unwrap_or("").trim(); + + if parts.len() == 3 && access.is_empty() { + return Err(miette!( + "--add-endpoint has an empty access segment in '{spec}'; omit it entirely if you do not need access or protocol fields" + )); + } + if !enforcement.is_empty() && protocol.is_empty() { + return Err(miette!( + "--add-endpoint cannot set enforcement without protocol in '{spec}'" + )); + } + if !access.is_empty() && !matches!(access, "read-only" | "read-write" | "full") { + return Err(miette!( + "--add-endpoint access segment must be one of read-only, read-write, or full; got '{access}' in '{spec}'" + )); + } + if !protocol.is_empty() && !matches!(protocol, "rest" 
| "sql") { + return Err(miette!( + "--add-endpoint protocol segment must be 'rest' or 'sql'; got '{protocol}' in '{spec}'" + )); + } + if !enforcement.is_empty() && !matches!(enforcement, "enforce" | "audit") { + return Err(miette!( + "--add-endpoint enforcement segment must be 'enforce' or 'audit'; got '{enforcement}' in '{spec}'" + )); + } + + Ok(NetworkEndpoint { + host, + port, + ports: vec![port], + protocol: protocol.to_string(), + enforcement: enforcement.to_string(), + access: access.to_string(), + ..Default::default() + }) +} + +fn parse_host(flag: &str, spec: &str, host: &str) -> Result<String> { + let host = host.trim(); + if host.is_empty() { + return Err(miette!("{flag} has an empty host segment in '{spec}'")); + } + if host.contains(char::is_whitespace) { + return Err(miette!( + "{flag} host must not contain whitespace in '{spec}'" + )); + } + if host.contains('/') { + return Err(miette!("{flag} host must not contain '/' in '{spec}'")); + } + Ok(host.to_string()) +} + +fn parse_port(flag: &str, spec: &str, port: &str) -> Result<u32> { + let port = port.trim(); + if port.is_empty() { + return Err(miette!("{flag} has an empty port segment in '{spec}'")); + } + let parsed = port.parse::<u32>().map_err(|_| { + miette!("{flag} port segment must be a base-10 integer, got '{port}' in '{spec}'") + })?; + if parsed == 0 || parsed > 65535 { + return Err(miette!( + "{flag} port must be in the range 1-65535, got '{parsed}' in '{spec}'" + )); + } + Ok(parsed) +} + +fn dedup_strings(values: &[String]) -> Vec<String> { + let mut deduped = Vec::new(); + for value in values { + let trimmed = value.trim(); + if !trimmed.is_empty() && !deduped.iter().any(|existing| existing == trimmed) { + deduped.push(trimmed.to_string()); + } + } + deduped +} + +#[cfg(test)] +mod tests { + use super::build_policy_update_plan; + + #[test] + fn parse_add_endpoint_basic_l4() { + let plan = + build_policy_update_plan(&["ghcr.io:443".to_string()], &[], &[], &[], &[], &[], None) + .expect("plan should build"); + 
assert_eq!(plan.merge_operations.len(), 1); + assert_eq!(plan.preview_operations.len(), 1); + } + + #[test] + fn parse_add_endpoint_rejects_bad_access() { + let error = build_policy_update_plan( + &["api.github.com:443:write-ish".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("access segment")); + } + + #[test] + fn parse_add_endpoint_allows_empty_access_when_protocol_present() { + build_policy_update_plan( + &["api.github.com:443::rest:enforce".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect("plan should build"); + } + + #[test] + fn parse_add_deny_rejects_empty_method() { + let error = build_policy_update_plan( + &[], + &[], + &["api.github.com:443::/repos/**".to_string()], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("METHOD")); + } + + #[test] + fn parse_add_allow_rejects_non_absolute_path() { + let error = build_policy_update_plan( + &[], + &[], + &[], + &["api.github.com:443:GET:repos/**".to_string()], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("path must start with '/'")); + } + + #[test] + fn parse_add_endpoint_rejects_enforcement_without_protocol() { + let error = build_policy_update_plan( + &["api.github.com:443:read-only::enforce".to_string()], + &[], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!( + error + .to_string() + .contains("cannot set enforcement without protocol") + ); + } + + #[test] + fn parse_remove_endpoint_rejects_out_of_range_port() { + let error = build_policy_update_plan( + &[], + &["api.github.com:70000".to_string()], + &[], + &[], + &[], + &[], + None, + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("range 1-65535")); + } + + #[test] + fn binary_requires_add_endpoint() { + let error = + build_policy_update_plan(&[], &[], &[], &[], &[], 
&["/usr/bin/gh".to_string()], None) + .expect_err("plan should fail"); + assert!(error.to_string().contains("--binary")); + } + + #[test] + fn rule_name_rejects_multiple_add_endpoints() { + let error = build_policy_update_plan( + &["api.github.com:443".to_string(), "ghcr.io:443".to_string()], + &[], + &[], + &[], + &[], + &[], + Some("shared"), + ) + .expect_err("plan should fail"); + assert!(error.to_string().contains("exactly one --add-endpoint")); + } +} diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index c41b53518..60e28ac7e 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -3,6 +3,7 @@ //! CLI command implementations. +use crate::policy_update::build_policy_update_plan; use crate::tls::{ TlsOptions, build_rustls_config, grpc_client, grpc_inference_client, require_tls_materials, }; @@ -27,7 +28,7 @@ use openshell_core::proto::{ ExecSandboxRequest, GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest, GetGatewayConfigRequest, GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest, GetSandboxPolicyStatusRequest, GetSandboxRequest, HealthRequest, ListProvidersRequest, - ListSandboxPoliciesRequest, ListSandboxesRequest, PolicyStatus, Provider, + ListSandboxPoliciesRequest, ListSandboxesRequest, PolicySource, PolicyStatus, Provider, RejectDraftChunkRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, SandboxTemplate, SetClusterInferenceRequest, SettingScope, SettingValue, UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, setting_value, @@ -4144,6 +4145,16 @@ fn format_setting_value(value: Option<&SettingValue>) -> String { } } +fn short_hash(hash: &str) -> &str { + if hash.len() >= 12 { &hash[..12] } else { hash } +} + +fn print_policy_merge_warnings(warnings: &[openshell_policy::PolicyMergeWarning]) { + for warning in warnings { + eprintln!("{} {}", "!".yellow().bold(), warning); + } +} + pub async fn 
sandbox_policy_set_global( server: &str, policy_path: &str, @@ -4172,6 +4183,7 @@ pub async fn sandbox_policy_set_global( setting_value: None, delete_setting: false, global: true, + merge_operations: vec![], }) .await .into_diagnostic()? @@ -4221,12 +4233,11 @@ pub async fn sandbox_settings_get( return Ok(()); } - let policy_source = - if response.policy_source == openshell_core::proto::PolicySource::Global as i32 { - "global" - } else { - "sandbox" - }; + let policy_source = if response.policy_source == PolicySource::Global as i32 { + "global" + } else { + "sandbox" + }; println!("Sandbox: {}", name); println!("Config Rev: {}", response.config_revision); @@ -4297,12 +4308,11 @@ fn settings_to_json_sandbox( name: &str, response: &openshell_core::proto::GetSandboxConfigResponse, ) -> serde_json::Value { - let policy_source = - if response.policy_source == openshell_core::proto::PolicySource::Global as i32 { - "global" - } else { - "sandbox" - }; + let policy_source = if response.policy_source == PolicySource::Global as i32 { + "global" + } else { + "sandbox" + }; let mut settings = serde_json::Map::new(); let mut keys: Vec<_> = response.settings.keys().cloned().collect(); @@ -4371,6 +4381,7 @@ pub async fn gateway_setting_set( setting_value: Some(setting_value), delete_setting: false, global: true, + merge_operations: vec![], }) .await .into_diagnostic()? @@ -4404,6 +4415,7 @@ pub async fn sandbox_setting_set( setting_value: Some(setting_value), delete_setting: false, global: false, + merge_operations: vec![], }) .await .into_diagnostic()? @@ -4437,6 +4449,7 @@ pub async fn gateway_setting_delete( setting_value: None, delete_setting: true, global: true, + merge_operations: vec![], }) .await .into_diagnostic()? @@ -4470,6 +4483,7 @@ pub async fn sandbox_setting_delete( setting_value: None, delete_setting: true, global: false, + merge_operations: vec![], }) .await .into_diagnostic()? 
@@ -4527,6 +4541,7 @@ pub async fn sandbox_policy_set( setting_value: None, delete_setting: false, global: false, + merge_operations: vec![], }) .await .into_diagnostic()?; @@ -4614,6 +4629,176 @@ pub async fn sandbox_policy_set( } } +#[allow(clippy::too_many_arguments)] +pub async fn sandbox_policy_update( + server: &str, + name: &str, + add_endpoints: &[String], + remove_endpoints: &[String], + add_deny: &[String], + add_allow: &[String], + remove_rules: &[String], + binaries: &[String], + rule_name: Option<&str>, + dry_run: bool, + wait: bool, + timeout_secs: u64, + tls: &TlsOptions, +) -> Result<()> { + if dry_run && wait { + return Err(miette!("--wait cannot be combined with --dry-run")); + } + + let plan = build_policy_update_plan( + add_endpoints, + remove_endpoints, + add_deny, + add_allow, + remove_rules, + binaries, + rule_name, + )?; + + let mut client = grpc_client(server, tls).await?; + let sandbox = client + .get_sandbox(GetSandboxRequest { + name: name.to_string(), + }) + .await + .into_diagnostic()? + .into_inner() + .sandbox + .ok_or_else(|| miette!("sandbox not found"))?; + + let current = client + .get_sandbox_config(GetSandboxConfigRequest { + sandbox_id: sandbox.id.clone(), + }) + .await + .into_diagnostic()? 
+ .into_inner(); + + if current.policy_source == PolicySource::Global as i32 { + return Err(miette!( + "policy is managed globally; delete the global policy before using `openshell policy update`" + )); + } + + let merged = openshell_policy::merge_policy( + current.policy.clone().unwrap_or_default(), + &plan.preview_operations, + ) + .map_err(|error| miette!("{error}"))?; + + if dry_run { + eprintln!( + "{} Dry run preview for {} incremental policy operation(s)", + "✓".green().bold(), + plan.preview_operations.len() + ); + print_policy_merge_warnings(&merged.warnings); + print_sandbox_policy(&merged.policy); + return Ok(()); + } + + let current_version = current.version; + let current_hash = current.policy_hash.clone(); + let response = client + .update_config(UpdateConfigRequest { + name: name.to_string(), + policy: None, + setting_key: String::new(), + setting_value: None, + delete_setting: false, + global: false, + merge_operations: plan.merge_operations, + }) + .await + .into_diagnostic()? 
+ .into_inner(); + + print_policy_merge_warnings(&merged.warnings); + + if response.version == current_version && response.policy_hash == current_hash { + eprintln!( + "{} Policy unchanged (version {}, hash: {})", + "·".dimmed(), + response.version, + short_hash(&response.policy_hash) + ); + return Ok(()); + } + + eprintln!( + "{} Policy version {} submitted (hash: {})", + "✓".green().bold(), + response.version, + short_hash(&response.policy_hash) + ); + + if !wait { + return Ok(()); + } + + let deadline = Instant::now() + Duration::from_secs(timeout_secs); + loop { + if Instant::now() > deadline { + eprintln!( + "{} Timeout waiting for policy version {} to load", + "✗".red().bold(), + response.version + ); + std::process::exit(124); + } + + tokio::time::sleep(Duration::from_secs(1)).await; + + let status_resp = client + .get_sandbox_policy_status(GetSandboxPolicyStatusRequest { + name: name.to_string(), + version: response.version, + global: false, + }) + .await + .into_diagnostic()?; + + let inner = status_resp.into_inner(); + if let Some(rev) = &inner.revision { + let status = PolicyStatus::try_from(rev.status).unwrap_or(PolicyStatus::Unspecified); + match status { + PolicyStatus::Loaded => { + eprintln!( + "{} Policy version {} loaded (active version: {})", + "✓".green().bold(), + rev.version, + inner.active_version + ); + return Ok(()); + } + PolicyStatus::Failed => { + eprintln!( + "{} Policy version {} failed to load: {}", + "✗".red().bold(), + rev.version, + rev.load_error + ); + std::process::exit(1); + } + PolicyStatus::Superseded => { + eprintln!( + "{} Policy version {} was superseded (active version: {})", + "⚠".yellow().bold(), + rev.version, + inner.active_version + ); + return Ok(()); + } + _ => {} + } + } + } +} + pub async fn sandbox_policy_get( server: &str, name: &str, diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index e3c26061a..2c8c0cc76 100644 --- a/crates/openshell-policy/src/lib.rs +++ 
b/crates/openshell-policy/src/lib.rs @@ -9,6 +9,8 @@ //! policy schema. Both parsing (YAML→proto) and serialization (proto→YAML) use //! these types, ensuring round-trip fidelity. +mod merge; + use std::collections::{BTreeMap, HashMap}; use std::fmt; use std::path::Path; @@ -20,6 +22,11 @@ use openshell_core::proto::{ }; use serde::{Deserialize, Serialize}; +pub use merge::{ + PolicyMergeError, PolicyMergeOp, PolicyMergeResult, PolicyMergeWarning, generated_rule_name, + merge_policy, +}; + // --------------------------------------------------------------------------- // YAML serde types (canonical — used for both parsing and serialization) // --------------------------------------------------------------------------- diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs new file mode 100644 index 000000000..5f5d2d40d --- /dev/null +++ b/crates/openshell-policy/src/merge.rs @@ -0,0 +1,1016 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashSet; + +use openshell_core::proto::{ + L7Allow, L7DenyRule, L7Rule, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, SandboxPolicy, +}; + +#[derive(Debug, Clone, PartialEq)] +pub enum PolicyMergeOp { + AddRule { + rule_name: String, + rule: NetworkPolicyRule, + }, + RemoveEndpoint { + rule_name: Option<String>, + host: String, + port: u32, + }, + RemoveRule { + rule_name: String, + }, + AddDenyRules { + host: String, + port: u32, + deny_rules: Vec<L7DenyRule>, + }, + AddAllowRules { + host: String, + port: u32, + rules: Vec<L7Rule>, + }, + RemoveBinary { + rule_name: String, + binary_path: String, + }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PolicyMergeWarning { + ExistingProtocolRetained { + host: String, + port: u32, + existing: String, + incoming: String, + }, + ExistingEnforcementRetained { + host: String, + port: u32, + existing: String, + incoming: String, + }, + ExistingTlsRetained { + host: String, + port: u32, + existing: String, + incoming: String, + }, + ExistingAccessRetained { + host: String, + port: u32, + existing: String, + incoming: String, + }, + ExpandedAccessPreset { + host: String, + port: u32, + access: String, + }, + IgnoredIncomingAccessBecauseRulesExist { + host: String, + port: u32, + incoming: String, + }, +} + +impl std::fmt::Display for PolicyMergeWarning { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::ExistingProtocolRetained { + host, + port, + existing, + incoming, + } => write!( + f, + "endpoint {host}:{port} keeps existing protocol '{existing}' and ignores incoming '{incoming}'" + ), + Self::ExistingEnforcementRetained { + host, + port, + existing, + incoming, + } => write!( + f, + "endpoint {host}:{port} keeps existing enforcement '{existing}' and ignores incoming '{incoming}'" + ), + Self::ExistingTlsRetained { + host, + port, + existing, + incoming, + } => write!( + f, + "endpoint {host}:{port} keeps existing tls mode
'{existing}' and ignores incoming '{incoming}'" + ), + Self::ExistingAccessRetained { + host, + port, + existing, + incoming, + } => write!( + f, + "endpoint {host}:{port} keeps existing access preset '{existing}' and ignores incoming '{incoming}'" + ), + Self::ExpandedAccessPreset { host, port, access } => write!( + f, + "expanded access preset '{access}' to explicit rules for endpoint {host}:{port}" + ), + Self::IgnoredIncomingAccessBecauseRulesExist { + host, + port, + incoming, + } => write!( + f, + "endpoint {host}:{port} already uses explicit rules; incoming access preset '{incoming}' was ignored" + ), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PolicyMergeError { + MissingRuleNameForAddRule, + InvalidEndpointReference { + host: String, + port: u32, + }, + EndpointNotFound { + host: String, + port: u32, + }, + EndpointHasNoL7Inspection { + host: String, + port: u32, + }, + UnsupportedEndpointProtocol { + host: String, + port: u32, + protocol: String, + }, + EndpointHasNoAllowBase { + host: String, + port: u32, + }, + UnsupportedAccessPreset { + host: String, + port: u32, + access: String, + }, +} + +impl std::fmt::Display for PolicyMergeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::MissingRuleNameForAddRule => write!(f, "add-rule operation requires a rule name"), + Self::InvalidEndpointReference { host, port } => { + write!(f, "invalid endpoint reference '{host}:{port}'") + } + Self::EndpointNotFound { host, port } => { + write!( + f, + "endpoint {host}:{port} was not found in the current policy" + ) + } + Self::EndpointHasNoL7Inspection { host, port } => write!( + f, + "endpoint {host}:{port} has no L7 inspection configured (protocol is empty)" + ), + Self::UnsupportedEndpointProtocol { + host, + port, + protocol, + } => write!( + f, + "endpoint {host}:{port} uses unsupported protocol '{protocol}'; this operation currently supports only protocol 'rest'" + ), + 
Self::EndpointHasNoAllowBase { host, port } => write!( + f, + "endpoint {host}:{port} has no base allow set; configure access or explicit allow rules before adding deny rules" + ), + Self::UnsupportedAccessPreset { host, port, access } => write!( + f, + "endpoint {host}:{port} uses unsupported access preset '{access}'" + ), + } + } +} + +impl std::error::Error for PolicyMergeError {} + +#[derive(Debug, Clone, PartialEq)] +pub struct PolicyMergeResult { + pub policy: SandboxPolicy, + pub warnings: Vec<PolicyMergeWarning>, + pub changed: bool, +} + +pub fn merge_policy( + policy: SandboxPolicy, + operations: &[PolicyMergeOp], +) -> Result<PolicyMergeResult, PolicyMergeError> { + let mut merged = policy.clone(); + let mut warnings = Vec::new(); + + for operation in operations { + apply_operation(&mut merged, operation, &mut warnings)?; + } + + let changed = merged != policy; + Ok(PolicyMergeResult { + policy: merged, + warnings, + changed, + }) +} + +pub fn generated_rule_name(host: &str, port: u32) -> String { + let sanitized = host + .replace(['.', '-'], "_") + .chars() + .filter(|c| c.is_alphanumeric() || *c == '_') + .collect::<String>(); + format!("allow_{sanitized}_{port}") +} + +fn apply_operation( + policy: &mut SandboxPolicy, + operation: &PolicyMergeOp, + warnings: &mut Vec<PolicyMergeWarning>, +) -> Result<(), PolicyMergeError> { + match operation { + PolicyMergeOp::AddRule { rule_name, rule } => { + add_rule(policy, rule_name, rule, warnings)?; + } + PolicyMergeOp::RemoveEndpoint { + rule_name, + host, + port, + } => { + remove_endpoint(policy, rule_name.as_deref(), host, *port); + } + PolicyMergeOp::RemoveRule { rule_name } => { + policy.network_policies.remove(rule_name); + } + PolicyMergeOp::AddDenyRules { + host, + port, + deny_rules, + } => { + let endpoint = find_endpoint_mut(policy, host, *port).ok_or_else(|| { + PolicyMergeError::EndpointNotFound { + host: host.clone(), + port: *port, + } + })?; + ensure_rest_endpoint(endpoint, host, *port)?; + if endpoint.access.is_empty() && endpoint.rules.is_empty() { + return
Err(PolicyMergeError::EndpointHasNoAllowBase { + host: host.clone(), + port: *port, + }); + } + append_unique_deny_rules(&mut endpoint.deny_rules, deny_rules); + } + PolicyMergeOp::AddAllowRules { host, port, rules } => { + let endpoint = find_endpoint_mut(policy, host, *port).ok_or_else(|| { + PolicyMergeError::EndpointNotFound { + host: host.clone(), + port: *port, + } + })?; + ensure_rest_endpoint(endpoint, host, *port)?; + expand_existing_access(endpoint, host, *port, warnings)?; + append_unique_l7_rules(&mut endpoint.rules, rules); + } + PolicyMergeOp::RemoveBinary { + rule_name, + binary_path, + } => { + let should_remove = if let Some(rule) = policy.network_policies.get_mut(rule_name) { + let original_len = rule.binaries.len(); + rule.binaries.retain(|binary| binary.path != *binary_path); + original_len != rule.binaries.len() && rule.binaries.is_empty() + } else { + false + }; + if should_remove { + policy.network_policies.remove(rule_name); + } + } + } + Ok(()) +} + +fn add_rule( + policy: &mut SandboxPolicy, + rule_name: &str, + incoming_rule: &NetworkPolicyRule, + warnings: &mut Vec<PolicyMergeWarning>, +) -> Result<(), PolicyMergeError> { + if rule_name.trim().is_empty() { + return Err(PolicyMergeError::MissingRuleNameForAddRule); + } + + let mut incoming_rule = incoming_rule.clone(); + normalize_rule(&mut incoming_rule); + if incoming_rule.name.is_empty() { + incoming_rule.name = rule_name.to_string(); + } + + let target_key = if policy.network_policies.contains_key(rule_name) { + Some(rule_name.to_string()) + } else { + let mut keys: Vec<_> = policy.network_policies.keys().cloned().collect(); + keys.sort(); + keys.into_iter().find(|key| { + policy + .network_policies + .get(key) + .is_some_and(|existing_rule| rules_share_endpoint(existing_rule, &incoming_rule)) + }) + }; + + if let Some(key) = target_key { + let existing_rule = policy + .network_policies + .get_mut(&key) + .expect("existing rule must be present"); + merge_rules(existing_rule, &incoming_rule, warnings)?; +
} else { + policy + .network_policies + .insert(rule_name.to_string(), incoming_rule); + } + + Ok(()) +} + +fn merge_rules( + existing_rule: &mut NetworkPolicyRule, + incoming_rule: &NetworkPolicyRule, + warnings: &mut Vec<PolicyMergeWarning>, +) -> Result<(), PolicyMergeError> { + append_unique_binaries(&mut existing_rule.binaries, &incoming_rule.binaries); + + for incoming_endpoint in &incoming_rule.endpoints { + let mut incoming_endpoint = incoming_endpoint.clone(); + normalize_endpoint(&mut incoming_endpoint); + if let Some(existing_endpoint) = + find_matching_endpoint_mut(&mut existing_rule.endpoints, &incoming_endpoint) + { + merge_endpoint(existing_endpoint, &incoming_endpoint, warnings)?; + } else { + existing_rule.endpoints.push(incoming_endpoint); + } + } + + Ok(()) +} + +fn merge_endpoint( + existing: &mut NetworkEndpoint, + incoming: &NetworkEndpoint, + warnings: &mut Vec<PolicyMergeWarning>, +) -> Result<(), PolicyMergeError> { + let host = if existing.host.is_empty() { + incoming.host.clone() + } else { + existing.host.clone() + }; + let port = canonical_ports(existing) + .into_iter() + .next() + .or_else(|| canonical_ports(incoming).into_iter().next()) + .unwrap_or(0); + + if existing.host.is_empty() { + existing.host = incoming.host.clone(); + } + + merge_endpoint_ports(existing, incoming); + let existing_protocol = existing.protocol.clone(); + merge_string_field( + &mut existing.protocol, + &incoming.protocol, + PolicyMergeWarning::ExistingProtocolRetained { + host: host.clone(), + port, + existing: existing_protocol, + incoming: incoming.protocol.clone(), + }, + warnings, + ); + let existing_enforcement = existing.enforcement.clone(); + merge_string_field( + &mut existing.enforcement, + &incoming.enforcement, + PolicyMergeWarning::ExistingEnforcementRetained { + host: host.clone(), + port, + existing: existing_enforcement, + incoming: incoming.enforcement.clone(), + }, + warnings, + ); + let existing_tls = existing.tls.clone(); + merge_string_field( + &mut existing.tls, + &incoming.tls,
+ PolicyMergeWarning::ExistingTlsRetained { + host: host.clone(), + port, + existing: existing_tls, + incoming: incoming.tls.clone(), + }, + warnings, + ); + + if !incoming.rules.is_empty() { + expand_existing_access(existing, &host, port, warnings)?; + append_unique_l7_rules(&mut existing.rules, &incoming.rules); + if !incoming.access.is_empty() { + warnings.push(PolicyMergeWarning::IgnoredIncomingAccessBecauseRulesExist { + host, + port, + incoming: incoming.access.clone(), + }); + } + } else if !incoming.access.is_empty() { + if !existing.rules.is_empty() { + warnings.push(PolicyMergeWarning::IgnoredIncomingAccessBecauseRulesExist { + host, + port, + incoming: incoming.access.clone(), + }); + } else if existing.access.is_empty() { + existing.access = incoming.access.clone(); + } else if existing.access != incoming.access { + warnings.push(PolicyMergeWarning::ExistingAccessRetained { + host, + port, + existing: existing.access.clone(), + incoming: incoming.access.clone(), + }); + } + } + + append_unique_deny_rules(&mut existing.deny_rules, &incoming.deny_rules); + append_unique_strings(&mut existing.allowed_ips, &incoming.allowed_ips); + normalize_endpoint(existing); + Ok(()) +} + +fn merge_string_field( + existing: &mut String, + incoming: &str, + warning: PolicyMergeWarning, + warnings: &mut Vec<PolicyMergeWarning>, +) { + if incoming.is_empty() { + return; + } + if existing.is_empty() { + *existing = incoming.to_string(); + } else if *existing != incoming { + warnings.push(warning); + } +} + +fn merge_endpoint_ports(existing: &mut NetworkEndpoint, incoming: &NetworkEndpoint) { + let mut ports = canonical_ports(existing); + for port in canonical_ports(incoming) { + if !ports.contains(&port) { + ports.push(port); + } + } + ports.sort_unstable(); + ports.dedup(); + existing.ports = ports.clone(); + existing.port = ports.first().copied().unwrap_or(0); +} + +fn rules_share_endpoint( + existing_rule: &NetworkPolicyRule, + incoming_rule: &NetworkPolicyRule, +) -> bool { +
incoming_rule.endpoints.iter().any(|incoming_endpoint| { + existing_rule + .endpoints + .iter() + .any(|existing_endpoint| endpoints_overlap(existing_endpoint, incoming_endpoint)) + }) +} + +fn endpoints_overlap(left: &NetworkEndpoint, right: &NetworkEndpoint) -> bool { + if !left.host.eq_ignore_ascii_case(&right.host) { + return false; + } + + let left_ports = canonical_ports(left); + let right_ports = canonical_ports(right); + left_ports.iter().any(|port| right_ports.contains(port)) +} + +fn canonical_ports(endpoint: &NetworkEndpoint) -> Vec<u32> { + if !endpoint.ports.is_empty() { + endpoint.ports.clone() + } else if endpoint.port > 0 { + vec![endpoint.port] + } else { + vec![] + } +} + +fn find_matching_endpoint_mut<'a>( + endpoints: &'a mut [NetworkEndpoint], + target: &NetworkEndpoint, +) -> Option<&'a mut NetworkEndpoint> { + endpoints + .iter_mut() + .find(|endpoint| endpoints_overlap(endpoint, target)) +} + +fn find_endpoint_mut<'a>( + policy: &'a mut SandboxPolicy, + host: &str, + port: u32, +) -> Option<&'a mut NetworkEndpoint> { + let mut keys: Vec<_> = policy.network_policies.keys().cloned().collect(); + keys.sort(); + let target_key = keys.into_iter().find(|key| { + policy.network_policies.get(key).is_some_and(|rule| { + rule.endpoints + .iter() + .any(|endpoint| endpoint_matches_host_port(endpoint, host, port)) + }) + })?; + + policy + .network_policies + .get_mut(&target_key) + .and_then(|rule| { + rule.endpoints + .iter_mut() + .find(|endpoint| endpoint_matches_host_port(endpoint, host, port)) + }) +} + +fn endpoint_matches_host_port(endpoint: &NetworkEndpoint, host: &str, port: u32) -> bool { + endpoint.host.eq_ignore_ascii_case(host) && canonical_ports(endpoint).contains(&port) +} + +fn ensure_rest_endpoint( + endpoint: &NetworkEndpoint, + host: &str, + port: u32, +) -> Result<(), PolicyMergeError> { + if endpoint.protocol.is_empty() { + return Err(PolicyMergeError::EndpointHasNoL7Inspection { + host: host.to_string(), + port, + }); + } + if
endpoint.protocol != "rest" { + return Err(PolicyMergeError::UnsupportedEndpointProtocol { + host: host.to_string(), + port, + protocol: endpoint.protocol.clone(), + }); + } + Ok(()) +} + +fn expand_existing_access( + endpoint: &mut NetworkEndpoint, + host: &str, + port: u32, + warnings: &mut Vec<PolicyMergeWarning>, +) -> Result<(), PolicyMergeError> { + if endpoint.access.is_empty() { + return Ok(()); + } + + let access = endpoint.access.clone(); + let expanded = + expand_access_preset(&access).ok_or_else(|| PolicyMergeError::UnsupportedAccessPreset { + host: host.to_string(), + port, + access: access.clone(), + })?; + endpoint.access.clear(); + append_unique_l7_rules(&mut endpoint.rules, &expanded); + warnings.push(PolicyMergeWarning::ExpandedAccessPreset { + host: host.to_string(), + port, + access, + }); + Ok(()) +} + +fn expand_access_preset(access: &str) -> Option<Vec<L7Rule>> { + let methods = match access { + "read-only" => vec!["GET", "HEAD", "OPTIONS"], + "read-write" => vec!["GET", "HEAD", "OPTIONS", "POST", "PUT", "PATCH"], + "full" => vec!["*"], + _ => return None, + }; + + Some( + methods + .into_iter() + .map(|method| L7Rule { + allow: Some(L7Allow { + method: method.to_string(), + path: "**".to_string(), + command: String::new(), + query: Default::default(), + }), + }) + .collect(), + ) +} + +fn append_unique_binaries(existing: &mut Vec<NetworkBinary>, incoming: &[NetworkBinary]) { + let mut seen: HashSet<String> = existing.iter().map(|binary| binary.path.clone()).collect(); + for binary in incoming { + if seen.insert(binary.path.clone()) { + existing.push(binary.clone()); + } + } +} + +fn append_unique_strings(existing: &mut Vec<String>, incoming: &[String]) { + let mut seen: HashSet<String> = existing.iter().cloned().collect(); + for value in incoming { + if seen.insert(value.clone()) { + existing.push(value.clone()); + } + } +} + +fn append_unique_l7_rules(existing: &mut Vec<L7Rule>, incoming: &[L7Rule]) { + for rule in incoming { + if !existing.contains(rule) { + existing.push(rule.clone()); + } + } +} + +fn
append_unique_deny_rules(existing: &mut Vec<L7DenyRule>, incoming: &[L7DenyRule]) { + for rule in incoming { + if !existing.contains(rule) { + existing.push(rule.clone()); + } + } +} + +fn normalize_rule(rule: &mut NetworkPolicyRule) { + for endpoint in &mut rule.endpoints { + normalize_endpoint(endpoint); + } + dedup_binaries(&mut rule.binaries); +} + +fn normalize_endpoint(endpoint: &mut NetworkEndpoint) { + let mut ports = canonical_ports(endpoint); + ports.sort_unstable(); + ports.dedup(); + endpoint.ports = ports.clone(); + endpoint.port = ports.first().copied().unwrap_or(0); + dedup_strings(&mut endpoint.allowed_ips); + dedup_l7_rules(&mut endpoint.rules); + dedup_deny_rules(&mut endpoint.deny_rules); +} + +fn dedup_strings(values: &mut Vec<String>) { + let mut seen = HashSet::new(); + values.retain(|value| seen.insert(value.clone())); +} + +fn dedup_binaries(values: &mut Vec<NetworkBinary>) { + let mut seen = HashSet::new(); + values.retain(|binary| seen.insert(binary.path.clone())); +} + +fn dedup_l7_rules(values: &mut Vec<L7Rule>) { + let mut deduped = Vec::with_capacity(values.len()); + for value in std::mem::take(values) { + if !deduped.contains(&value) { + deduped.push(value); + } + } + *values = deduped; +} + +fn dedup_deny_rules(values: &mut Vec<L7DenyRule>) { + let mut deduped = Vec::with_capacity(values.len()); + for value in std::mem::take(values) { + if !deduped.contains(&value) { + deduped.push(value); + } + } + *values = deduped; +} + +fn remove_endpoint(policy: &mut SandboxPolicy, rule_name: Option<&str>, host: &str, port: u32) { + let target_keys: Vec<String> = if let Some(rule_name) = rule_name { + if policy.network_policies.contains_key(rule_name) { + vec![rule_name.to_string()] + } else { + vec![] + } + } else { + let mut keys: Vec<_> = policy.network_policies.keys().cloned().collect(); + keys.sort(); + keys + }; + + let mut empty_rules = Vec::new(); + for key in target_keys { + if let Some(rule) = policy.network_policies.get_mut(&key) { + rule.endpoints.retain_mut(|endpoint| { + if
!endpoint_matches_host_port(endpoint, host, port) { + return true; + } + + let mut remaining_ports = canonical_ports(endpoint); + remaining_ports.retain(|existing_port| *existing_port != port); + remaining_ports.sort_unstable(); + remaining_ports.dedup(); + + if remaining_ports.is_empty() { + return false; + } + + endpoint.ports = remaining_ports.clone(); + endpoint.port = remaining_ports[0]; + true + }); + + if rule.endpoints.is_empty() { + empty_rules.push(key); + } + } + } + + for key in empty_rules { + policy.network_policies.remove(&key); + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::{ + PolicyMergeError, PolicyMergeOp, PolicyMergeWarning, generated_rule_name, merge_policy, + }; + use crate::restrictive_default_policy; + use openshell_core::proto::{ + L7Allow, L7DenyRule, L7Rule, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, + }; + + fn endpoint(host: &str, port: u32) -> NetworkEndpoint { + NetworkEndpoint { + host: host.to_string(), + port, + ports: vec![port], + ..Default::default() + } + } + + fn rule_with_endpoint(name: &str, host: &str, port: u32) -> NetworkPolicyRule { + NetworkPolicyRule { + name: name.to_string(), + endpoints: vec![endpoint(host, port)], + ..Default::default() + } + } + + fn rest_rule(method: &str, path: &str) -> L7Rule { + L7Rule { + allow: Some(L7Allow { + method: method.to_string(), + path: path.to_string(), + command: String::new(), + query: HashMap::new(), + }), + } + } + + #[test] + fn generated_rule_name_sanitizes_host() { + assert_eq!( + generated_rule_name("api.github.com", 443), + "allow_api_github_com_443" + ); + } + + #[test] + fn add_rule_merges_l7_fields_into_existing_endpoint() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "existing".to_string(), + NetworkPolicyRule { + name: "existing".to_string(), + endpoints: vec![endpoint("api.github.com", 443)], + binaries: vec![NetworkBinary { + path: "/usr/bin/curl".to_string(), + 
..Default::default() + }], + }, + ); + + let incoming = NetworkPolicyRule { + name: "incoming".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.github.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + enforcement: "enforce".to_string(), + rules: vec![rest_rule("GET", "/repos/**")], + ..Default::default() + }], + binaries: vec![NetworkBinary { + path: "/usr/bin/gh".to_string(), + ..Default::default() + }], + }; + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddRule { + rule_name: "allow_api_github_com_443".to_string(), + rule: incoming, + }], + ) + .expect("merge should succeed"); + + let rule = &result.policy.network_policies["existing"]; + let endpoint = &rule.endpoints[0]; + assert_eq!(endpoint.protocol, "rest"); + assert_eq!(endpoint.enforcement, "enforce"); + assert_eq!(endpoint.rules.len(), 1); + assert_eq!(rule.binaries.len(), 2); + } + + #[test] + fn add_allow_expands_access_preset() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "github".to_string(), + NetworkPolicyRule { + name: "github".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.github.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + access: "read-only".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let result = merge_policy( + policy, + &[PolicyMergeOp::AddAllowRules { + host: "api.github.com".to_string(), + port: 443, + rules: vec![rest_rule("POST", "/repos/*/issues")], + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["github"].endpoints[0]; + assert!(endpoint.access.is_empty()); + assert_eq!(endpoint.rules.len(), 4); + assert!(result.warnings.iter().any(|warning| matches!( + warning, + PolicyMergeWarning::ExpandedAccessPreset { access, .. 
} if access == "read-only" + ))); + } + + #[test] + fn add_deny_requires_rest_protocol() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "db".to_string(), + NetworkPolicyRule { + name: "db".to_string(), + endpoints: vec![NetworkEndpoint { + host: "db.example.com".to_string(), + port: 5432, + ports: vec![5432], + protocol: "sql".to_string(), + access: "full".to_string(), + ..Default::default() + }], + ..Default::default() + }, + ); + + let error = merge_policy( + policy, + &[PolicyMergeOp::AddDenyRules { + host: "db.example.com".to_string(), + port: 5432, + deny_rules: vec![L7DenyRule { + method: "POST".to_string(), + path: "/admin".to_string(), + ..Default::default() + }], + }], + ) + .expect_err("merge should fail"); + + assert!(matches!( + error, + PolicyMergeError::UnsupportedEndpointProtocol { protocol, .. } if protocol == "sql" + )); + } + + #[test] + fn remove_endpoint_drops_only_requested_port() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "multi".to_string(), + NetworkPolicyRule { + name: "multi".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.example.com".to_string(), + port: 80, + ports: vec![80, 443], + ..Default::default() + }], + ..Default::default() + }, + ); + + let result = merge_policy( + policy, + &[PolicyMergeOp::RemoveEndpoint { + rule_name: None, + host: "api.example.com".to_string(), + port: 443, + }], + ) + .expect("merge should succeed"); + + let endpoint = &result.policy.network_policies["multi"].endpoints[0]; + assert_eq!(endpoint.ports, vec![80]); + assert_eq!(endpoint.port, 80); + } + + #[test] + fn remove_binary_removes_rule_when_last_binary_is_deleted() { + let mut policy = restrictive_default_policy(); + policy.network_policies.insert( + "github".to_string(), + NetworkPolicyRule { + name: "github".to_string(), + endpoints: vec![endpoint("api.github.com", 443)], + binaries: vec![NetworkBinary { + path: "/usr/bin/gh".to_string(), + 
..Default::default() + }], + }, + ); + + let result = merge_policy( + policy, + &[PolicyMergeOp::RemoveBinary { + rule_name: "github".to_string(), + binary_path: "/usr/bin/gh".to_string(), + }], + ) + .expect("merge should succeed"); + + assert!(!result.policy.network_policies.contains_key("github")); + } + + #[test] + fn add_rule_without_existing_match_inserts_requested_key() { + let policy = restrictive_default_policy(); + let result = merge_policy( + policy, + &[PolicyMergeOp::AddRule { + rule_name: "allow_api_example_com_443".to_string(), + rule: rule_with_endpoint("custom", "api.example.com", 443), + }], + ) + .expect("merge should succeed"); + + assert!( + result + .policy + .network_policies + .contains_key("allow_api_example_com_443") + ); + } +} diff --git a/crates/openshell-sandbox/src/grpc_client.rs b/crates/openshell-sandbox/src/grpc_client.rs index 5503637ee..c97d1d792 100644 --- a/crates/openshell-sandbox/src/grpc_client.rs +++ b/crates/openshell-sandbox/src/grpc_client.rs @@ -133,6 +133,7 @@ async fn sync_policy_with_client( setting_value: None, delete_setting: false, global: false, + merge_operations: vec![], }) .await .into_diagnostic() diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 58d0c03cf..e6c37a8e5 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -12,8 +12,10 @@ use crate::ServerState; use crate::persistence::{DraftChunkRecord, PolicyRecord, Store}; +use openshell_core::proto::policy_merge_operation; use openshell_core::proto::setting_value; use openshell_core::proto::{ + AddAllowRules as ProtoAddAllowRules, AddDenyRules as ProtoAddDenyRules, ApproveAllDraftChunksRequest, ApproveAllDraftChunksResponse, ApproveDraftChunkRequest, ApproveDraftChunkResponse, ClearDraftChunksRequest, ClearDraftChunksResponse, DraftHistoryEntry, EditDraftChunkRequest, EditDraftChunkResponse, EffectiveSetting, @@ -22,15 +24,16 @@ use 
openshell_core::proto::{ GetSandboxConfigResponse, GetSandboxLogsRequest, GetSandboxLogsResponse, GetSandboxPolicyStatusRequest, GetSandboxPolicyStatusResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, - ListSandboxPoliciesRequest, ListSandboxPoliciesResponse, PolicyChunk, PolicySource, - PolicyStatus, PushSandboxLogsRequest, PushSandboxLogsResponse, RejectDraftChunkRequest, - RejectDraftChunkResponse, ReportPolicyStatusRequest, ReportPolicyStatusResponse, - SandboxLogLine, SandboxPolicyRevision, SettingScope, SettingValue, SubmitPolicyAnalysisRequest, - SubmitPolicyAnalysisResponse, UndoDraftChunkRequest, UndoDraftChunkResponse, - UpdateConfigRequest, UpdateConfigResponse, + ListSandboxPoliciesRequest, ListSandboxPoliciesResponse, PolicyChunk, PolicyMergeOperation, + PolicySource, PolicyStatus, PushSandboxLogsRequest, PushSandboxLogsResponse, + RejectDraftChunkRequest, RejectDraftChunkResponse, ReportPolicyStatusRequest, + ReportPolicyStatusResponse, SandboxLogLine, SandboxPolicyRevision, SettingScope, SettingValue, + SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, UndoDraftChunkRequest, + UndoDraftChunkResponse, UpdateConfigRequest, UpdateConfigResponse, }; use openshell_core::proto::{Sandbox, SandboxPolicy as ProtoSandboxPolicy}; use openshell_core::settings::{self, SettingValueKind}; +use openshell_policy::{PolicyMergeOp, merge_policy}; use prost::Message; use sha2::{Digest, Sha256}; use std::collections::{BTreeMap, HashMap}; @@ -242,21 +245,32 @@ pub(super) async fn handle_update_config( let key = req.setting_key.trim(); let has_policy = req.policy.is_some(); let has_setting = !key.is_empty(); + let has_merge_ops = !req.merge_operations.is_empty(); + let mut mutation_count = 0_u8; + mutation_count += u8::from(has_policy); + mutation_count += u8::from(has_setting); + mutation_count += u8::from(has_merge_ops); - if has_policy && has_setting { + if mutation_count > 1 { return Err(Status::invalid_argument( - "policy 
and setting_key cannot be set in the same request", + "policy, setting_key, and merge_operations are mutually exclusive", )); } - if !has_policy && !has_setting { + if mutation_count == 0 { return Err(Status::invalid_argument( - "either policy or setting_key must be provided", + "one of policy, setting_key, or merge_operations must be provided", )); } if req.global { let _settings_guard = state.settings_mutex.lock().await; + if has_merge_ops { + return Err(Status::invalid_argument( + "merge_operations are not supported for global policy updates", + )); + } + if has_policy { if req.delete_setting { return Err(Status::invalid_argument( @@ -493,6 +507,45 @@ pub(super) async fn handle_update_config( })); } + if has_merge_ops { + let global_settings = load_global_settings(state.store.as_ref()).await?; + if global_settings.settings.contains_key(POLICY_SETTING_KEY) { + return Err(Status::failed_precondition( + "policy is managed globally; delete global policy before sandbox policy update", + )); + } + + let spec = sandbox + .spec + .as_ref() + .ok_or_else(|| Status::internal("sandbox has no spec"))?; + let merge_ops = parse_merge_operations(&req.merge_operations)?; + validate_merge_operations_for_server(&merge_ops)?; + let (version, hash) = apply_merge_operations_with_retry( + state.store.as_ref(), + &sandbox_id, + spec.policy.as_ref(), + &merge_ops, + ) + .await?; + + state.sandbox_watch_bus.notify(&sandbox_id); + info!( + sandbox_id = %sandbox_id, + version, + policy_hash = %hash, + operation_count = merge_ops.len(), + "UpdateConfig: merged incremental policy operations" + ); + + return Ok(Response::new(UpdateConfigResponse { + version: u32::try_from(version).unwrap_or(0), + policy_hash: hash, + settings_revision: 0, + deleted: false, + })); + } + // Sandbox-scoped policy update. 
let mut new_policy = req .policy @@ -1677,90 +1730,203 @@ async fn require_no_global_policy(state: &ServerState) -> Result<(), Status> { Ok(()) } -pub(super) async fn merge_chunk_into_policy( - store: &Store, - sandbox_id: &str, - chunk: &DraftChunkRecord, -) -> Result<(i64, String), Status> { - use openshell_core::proto::NetworkPolicyRule; +fn parse_merge_operations( + proto_ops: &[PolicyMergeOperation], +) -> Result<Vec<PolicyMergeOp>, Status> { + proto_ops + .iter() + .enumerate() + .map(|(index, operation)| { + let Some(operation) = operation.operation.as_ref() else { + return Err(Status::invalid_argument(format!( + "merge_operations[{index}] is missing an operation" + ))); + }; - let rule = NetworkPolicyRule::decode(chunk.proposed_rule.as_slice()) - .map_err(|e| Status::internal(format!("decode proposed_rule failed: {e}")))?; + match operation { + policy_merge_operation::Operation::AddRule(add_rule) => { + let rule_name = add_rule.rule_name.trim(); + if rule_name.is_empty() { + return Err(Status::invalid_argument(format!( + "merge_operations[{index}].add_rule.rule_name is required" + ))); + } + if add_rule.rule.as_ref().is_none_or(|rule| rule.endpoints.is_empty()) { + return Err(Status::invalid_argument(format!( + "merge_operations[{index}].add_rule.rule must contain at least one endpoint" + ))); + } + Ok(PolicyMergeOp::AddRule { + rule_name: rule_name.to_string(), + rule: add_rule.rule.clone().unwrap_or_default(), + }) + } + policy_merge_operation::Operation::RemoveEndpoint(remove_endpoint) => { + if remove_endpoint.host.trim().is_empty() || remove_endpoint.port == 0 { + return Err(Status::invalid_argument(format!( + "merge_operations[{index}].remove_endpoint requires host and non-zero port" + ))); + } + let rule_name = if remove_endpoint.rule_name.trim().is_empty() { + None + } else { + Some(remove_endpoint.rule_name.trim().to_string()) + }; + Ok(PolicyMergeOp::RemoveEndpoint { + rule_name, + host: remove_endpoint.host.trim().to_string(), + port: remove_endpoint.port, + }) + }
+ policy_merge_operation::Operation::RemoveRule(remove_rule) => { + let rule_name = remove_rule.rule_name.trim(); + if rule_name.is_empty() { + return Err(Status::invalid_argument(format!( + "merge_operations[{index}].remove_rule.rule_name is required" + ))); + } + Ok(PolicyMergeOp::RemoveRule { + rule_name: rule_name.to_string(), + }) + } + policy_merge_operation::Operation::AddDenyRules(add_deny_rules) => { + parse_proto_add_deny_rules(index, add_deny_rules) + } + policy_merge_operation::Operation::AddAllowRules(add_allow_rules) => { + parse_proto_add_allow_rules(index, add_allow_rules) + } + policy_merge_operation::Operation::RemoveBinary(remove_binary) => { + let rule_name = remove_binary.rule_name.trim(); + let binary_path = remove_binary.binary_path.trim(); + if rule_name.is_empty() || binary_path.is_empty() { + return Err(Status::invalid_argument(format!( + "merge_operations[{index}].remove_binary requires rule_name and binary_path" + ))); + } + Ok(PolicyMergeOp::RemoveBinary { + rule_name: rule_name.to_string(), + binary_path: binary_path.to_string(), + }) + } + } + }) + .collect() +} + +fn parse_proto_add_deny_rules( + index: usize, + add_deny_rules: &ProtoAddDenyRules, +) -> Result<PolicyMergeOp, Status> { + if add_deny_rules.host.trim().is_empty() + || add_deny_rules.port == 0 + || add_deny_rules.deny_rules.is_empty() + { + return Err(Status::invalid_argument(format!( + "merge_operations[{index}].add_deny_rules requires host, non-zero port, and at least one deny rule" + ))); + } + + Ok(PolicyMergeOp::AddDenyRules { + host: add_deny_rules.host.trim().to_string(), + port: add_deny_rules.port, + deny_rules: add_deny_rules.deny_rules.clone(), + }) +} + +fn parse_proto_add_allow_rules( + index: usize, + add_allow_rules: &ProtoAddAllowRules, +) -> Result<PolicyMergeOp, Status> { + if add_allow_rules.host.trim().is_empty() + || add_allow_rules.port == 0 + || add_allow_rules.rules.is_empty() + { + return Err(Status::invalid_argument(format!( + "merge_operations[{index}].add_allow_rules requires host,
non-zero port, and at least one allow rule" + ))); + } + if add_allow_rules + .rules + .iter() + .any(|rule| rule.allow.as_ref().is_none()) + { + return Err(Status::invalid_argument(format!( + "merge_operations[{index}].add_allow_rules rules must include allow payloads" + ))); + } + + Ok(PolicyMergeOp::AddAllowRules { + host: add_allow_rules.host.trim().to_string(), + port: add_allow_rules.port, + rules: add_allow_rules.rules.clone(), + }) +} - // Defense-in-depth: reject proposed rules targeting always-blocked - // destinations. Even if the sandbox mapper didn't filter these (e.g., - // an older sandbox version), the proxy will deny them at runtime. - validate_rule_not_always_blocked(&rule)?; +fn validate_merge_operations_for_server(operations: &[PolicyMergeOp]) -> Result<(), Status> { + for operation in operations { + if let PolicyMergeOp::AddRule { rule, .. } = operation { + validate_rule_not_always_blocked(rule)?; + } + } + Ok(()) +} + +fn map_policy_merge_error(error: openshell_policy::PolicyMergeError) -> Status { + match error { + openshell_policy::PolicyMergeError::MissingRuleNameForAddRule + | openshell_policy::PolicyMergeError::InvalidEndpointReference { .. } + | openshell_policy::PolicyMergeError::UnsupportedAccessPreset { .. } => { + Status::invalid_argument(error.to_string()) + } + openshell_policy::PolicyMergeError::EndpointNotFound { .. } + | openshell_policy::PolicyMergeError::EndpointHasNoL7Inspection { .. } + | openshell_policy::PolicyMergeError::UnsupportedEndpointProtocol { .. } + | openshell_policy::PolicyMergeError::EndpointHasNoAllowBase { .. 
} => { + Status::failed_precondition(error.to_string()) + } + } +} +async fn apply_merge_operations_with_retry( + store: &Store, + sandbox_id: &str, + baseline_policy: Option<&ProtoSandboxPolicy>, + operations: &[PolicyMergeOp], +) -> Result<(i64, String), Status> { for attempt in 1..=MERGE_RETRY_LIMIT { let latest = store .get_latest_policy(sandbox_id) .await .map_err(|e| Status::internal(format!("fetch latest policy failed: {e}")))?; - let mut policy = if let Some(ref record) = latest { + let current_policy = if let Some(ref record) = latest { ProtoSandboxPolicy::decode(record.policy_payload.as_slice()) .map_err(|e| Status::internal(format!("decode current policy failed: {e}")))? } else { - ProtoSandboxPolicy::default() + baseline_policy.cloned().unwrap_or_default() }; - let base_version = latest.as_ref().map_or(0, |r| r.version); + let merged = merge_policy(current_policy, operations).map_err(map_policy_merge_error)?; + let new_policy = merged.policy; + let hash = deterministic_policy_hash(&new_policy); - let chunk_host_lc = chunk.host.to_lowercase(); - let chunk_port = chunk.port as u32; + if let Some(baseline_policy) = baseline_policy { + validate_static_fields_unchanged(baseline_policy, &new_policy)?; + } + validate_policy_safety(&new_policy)?; - let merge_key = if policy.network_policies.contains_key(&chunk.rule_name) { - Some(chunk.rule_name.clone()) - } else { - policy - .network_policies - .iter() - .find_map(|(key, existing_rule)| { - let has_match = existing_rule.endpoints.iter().any(|ep| { - let host_match = ep.host.to_lowercase() == chunk_host_lc; - let port_match = if ep.ports.is_empty() { - ep.port == chunk_port - } else { - ep.ports.contains(&chunk_port) - }; - host_match && port_match - }); - has_match.then(|| key.clone()) - }) - }; + if let Some(ref current) = latest + && current.policy_hash == hash + { + return Ok((current.version, hash)); + } - if let Some(key) = merge_key { - let existing = policy.network_policies.get_mut(&key).unwrap(); - for 
b in &rule.binaries { - if !existing.binaries.iter().any(|eb| eb.path == b.path) { - existing.binaries.push(b.clone()); - } - } - for ep in &rule.endpoints { - if let Some(existing_ep) = existing.endpoints.iter_mut().find(|e| { - e.host.to_lowercase() == ep.host.to_lowercase() - && (e.port == ep.port - || (!e.ports.is_empty() && e.ports.contains(&ep.port))) - }) { - for ip in &ep.allowed_ips { - if !existing_ep.allowed_ips.contains(ip) { - existing_ep.allowed_ips.push(ip.clone()); - } - } - } else { - existing.endpoints.push(ep.clone()); - } - } - } else { - policy - .network_policies - .insert(chunk.rule_name.clone(), rule.clone()); + if latest.is_none() && !merged.changed { + return Ok((0, hash)); } - let payload = policy.encode_to_vec(); - let hash = deterministic_policy_hash(&policy); - let next_version = base_version + 1; + let payload = new_policy.encode_to_vec(); + let next_version = latest.as_ref().map_or(1, |record| record.version + 1); let policy_id = uuid::Uuid::new_v4().to_string(); match store @@ -1775,10 +1941,10 @@ pub(super) async fn merge_chunk_into_policy( if attempt > 1 { info!( sandbox_id = %sandbox_id, - rule_name = %chunk.rule_name, attempt, version = next_version, - "merge_chunk_into_policy: succeeded after version conflict retry" + operation_count = operations.len(), + "apply_merge_operations_with_retry: succeeded after version conflict retry" ); } @@ -1789,10 +1955,10 @@ pub(super) async fn merge_chunk_into_policy( if msg.contains("UNIQUE") || msg.contains("unique") || msg.contains("duplicate") { warn!( sandbox_id = %sandbox_id, - rule_name = %chunk.rule_name, attempt, conflicting_version = next_version, - "merge_chunk_into_policy: version conflict, retrying" + operation_count = operations.len(), + "apply_merge_operations_with_retry: version conflict, retrying" ); tokio::task::yield_now().await; continue; @@ -1805,90 +1971,44 @@ pub(super) async fn merge_chunk_into_policy( } Err(Status::aborted(format!( - "merge_chunk_into_policy: gave up 
after {} version conflict retries for rule '{}'", - MERGE_RETRY_LIMIT, chunk.rule_name + "apply_merge_operations_with_retry: gave up after {MERGE_RETRY_LIMIT} version conflict retries" ))) } +pub(super) async fn merge_chunk_into_policy( + store: &Store, + sandbox_id: &str, + chunk: &DraftChunkRecord, +) -> Result<(i64, String), Status> { + let rule = openshell_core::proto::NetworkPolicyRule::decode(chunk.proposed_rule.as_slice()) + .map_err(|e| Status::internal(format!("decode proposed_rule failed: {e}")))?; + apply_merge_operations_with_retry( + store, + sandbox_id, + None, + &[PolicyMergeOp::AddRule { + rule_name: chunk.rule_name.clone(), + rule, + }], + ) + .await +} + async fn remove_chunk_from_policy( state: &ServerState, sandbox_id: &str, chunk: &DraftChunkRecord, ) -> Result<(i64, String), Status> { - for attempt in 1..=MERGE_RETRY_LIMIT { - let latest = state - .store - .get_latest_policy(sandbox_id) - .await - .map_err(|e| Status::internal(format!("fetch latest policy failed: {e}")))? 
- .ok_or_else(|| Status::internal("no active policy to undo from"))?; - - let mut policy = ProtoSandboxPolicy::decode(latest.policy_payload.as_slice()) - .map_err(|e| Status::internal(format!("decode current policy failed: {e}")))?; - - let should_remove = - if let Some(existing) = policy.network_policies.get_mut(&chunk.rule_name) { - existing.binaries.retain(|b| b.path != chunk.binary); - existing.binaries.is_empty() - } else { - false - }; - if should_remove { - policy.network_policies.remove(&chunk.rule_name); - } - - let payload = policy.encode_to_vec(); - let hash = deterministic_policy_hash(&policy); - let next_version = latest.version + 1; - let policy_id = uuid::Uuid::new_v4().to_string(); - - match state - .store - .put_policy_revision(&policy_id, sandbox_id, next_version, &payload, &hash) - .await - { - Ok(()) => { - let _ = state - .store - .supersede_older_policies(sandbox_id, next_version) - .await; - - if attempt > 1 { - info!( - sandbox_id = %sandbox_id, - rule_name = %chunk.rule_name, - attempt, - version = next_version, - "remove_chunk_from_policy: succeeded after version conflict retry" - ); - } - - return Ok((next_version, hash)); - } - Err(e) => { - let msg = e.to_string(); - if msg.contains("UNIQUE") || msg.contains("unique") || msg.contains("duplicate") { - warn!( - sandbox_id = %sandbox_id, - rule_name = %chunk.rule_name, - attempt, - conflicting_version = next_version, - "remove_chunk_from_policy: version conflict, retrying" - ); - tokio::task::yield_now().await; - continue; - } - return Err(Status::internal(format!( - "persist policy revision failed: {e}" - ))); - } - } - } - - Err(Status::aborted(format!( - "remove_chunk_from_policy: gave up after {} version conflict retries for rule '{}'", - MERGE_RETRY_LIMIT, chunk.rule_name - ))) + apply_merge_operations_with_retry( + state.store.as_ref(), + sandbox_id, + None, + &[PolicyMergeOp::RemoveBinary { + rule_name: chunk.rule_name.clone(), + binary_path: chunk.binary.clone(), + }], + ) + .await 
} // --------------------------------------------------------------------------- @@ -2151,6 +2271,7 @@ mod tests { use super::*; use crate::persistence::Store; use std::collections::HashMap; + use std::sync::Arc; use tonic::Code; // ---- Sandbox without policy ---- @@ -2184,9 +2305,7 @@ mod tests { #[tokio::test] async fn sandbox_policy_backfill_on_update_when_no_baseline() { - use openshell_core::proto::{ - FilesystemPolicy, LandlockPolicy, ProcessPolicy, SandboxPhase, SandboxSpec, - }; + use openshell_core::proto::{FilesystemPolicy, LandlockPolicy, SandboxPhase, SandboxSpec}; let store = Store::connect("sqlite::memory:").await.unwrap(); @@ -2488,6 +2607,89 @@ mod tests { assert!(policy.network_policies.contains_key("allow_10_0_0_5_8080")); } + #[tokio::test] + async fn concurrent_merge_batches_preserve_both_updates() { + use openshell_core::proto::{ + L7Allow, L7DenyRule, L7Rule, NetworkEndpoint, NetworkPolicyRule, SandboxPolicy, + }; + + let store = Store::connect("sqlite::memory:").await.unwrap(); + let sandbox_id = "sb-concurrent-merge"; + + let initial_policy = SandboxPolicy { + network_policies: [( + "github".to_string(), + NetworkPolicyRule { + name: "github".to_string(), + endpoints: vec![NetworkEndpoint { + host: "api.github.com".to_string(), + port: 443, + ports: vec![443], + protocol: "rest".to_string(), + access: "read-only".to_string(), + ..Default::default() + }], + ..Default::default() + }, + )] + .into_iter() + .collect(), + ..Default::default() + }; + store + .put_policy_revision( + "p-seed", + sandbox_id, + 1, + &initial_policy.encode_to_vec(), + "seed-hash", + ) + .await + .unwrap(); + + let add_allow = [PolicyMergeOp::AddAllowRules { + host: "api.github.com".to_string(), + port: 443, + rules: vec![L7Rule { + allow: Some(L7Allow { + method: "POST".to_string(), + path: "/repos/*/issues".to_string(), + command: String::new(), + query: HashMap::new(), + }), + }], + }]; + let add_deny = [PolicyMergeOp::AddDenyRules { + host: 
"api.github.com".to_string(), + port: 443, + deny_rules: vec![L7DenyRule { + method: "POST".to_string(), + path: "/admin".to_string(), + query: HashMap::new(), + ..Default::default() + }], + }]; + + let (left, right) = tokio::join!( + apply_merge_operations_with_retry(&store, sandbox_id, None, &add_allow), + apply_merge_operations_with_retry(&store, sandbox_id, None, &add_deny), + ); + + let mut versions = vec![left.unwrap().0, right.unwrap().0]; + versions.sort_unstable(); + assert_eq!(versions, vec![2, 3]); + + let latest = store.get_latest_policy(sandbox_id).await.unwrap().unwrap(); + assert_eq!(latest.version, 3); + + let policy = SandboxPolicy::decode(latest.policy_payload.as_slice()).unwrap(); + let endpoint = &policy.network_policies["github"].endpoints[0]; + assert!(endpoint.access.is_empty()); + assert_eq!(endpoint.rules.len(), 4); + assert_eq!(endpoint.deny_rules.len(), 1); + assert_eq!(endpoint.deny_rules[0].path, "/admin"); + } + // ---- validate_rule_not_always_blocked ---- #[test] @@ -2608,7 +2810,7 @@ mod tests { let global = StoredSettings::default(); let sandbox = StoredSettings::default(); let merged = merge_effective_settings(&global, &sandbox).unwrap(); - for registered in openshell_core::settings::REGISTERED_SETTINGS { + for registered in settings::REGISTERED_SETTINGS { let setting = merged .get(registered.key) .unwrap_or_else(|| panic!("missing registered key {}", registered.key)); @@ -2625,7 +2827,7 @@ mod tests { fn materialize_global_settings_includes_unset_registered_keys() { let global = StoredSettings::default(); let materialized = materialize_global_settings(&global).unwrap(); - for registered in openshell_core::settings::REGISTERED_SETTINGS { + for registered in settings::REGISTERED_SETTINGS { let setting = materialized .get(registered.key) .unwrap_or_else(|| panic!("missing registered key {}", registered.key)); @@ -2785,7 +2987,7 @@ mod tests { let global = StoredSettings::default(); let sandbox = StoredSettings::default(); let merged 
= merge_effective_settings(&global, &sandbox).unwrap(); - for registered in openshell_core::settings::REGISTERED_SETTINGS { + for registered in settings::REGISTERED_SETTINGS { let setting = merged.get(registered.key).unwrap(); assert_eq!(setting.scope, SettingScope::Unspecified as i32); assert!(setting.value.is_none()); @@ -3066,12 +3268,12 @@ mod tests { #[tokio::test] async fn concurrent_global_setting_mutations_are_serialized() { - let store = std::sync::Arc::new( + let store = Arc::new( Store::connect("sqlite::memory:?cache=shared") .await .unwrap(), ); - let mutex = std::sync::Arc::new(tokio::sync::Mutex::new(())); + let mutex = Arc::new(tokio::sync::Mutex::new(())); let n = 50; let mut handles = Vec::with_capacity(n); @@ -3101,7 +3303,7 @@ mod tests { #[tokio::test] async fn concurrent_global_setting_mutations_without_lock_can_lose_writes() { - let store = std::sync::Arc::new( + let store = Arc::new( Store::connect("sqlite::memory:?cache=shared") .await .unwrap(), diff --git a/crates/openshell-tui/src/lib.rs b/crates/openshell-tui/src/lib.rs index f187f59fb..63cfb79d6 100644 --- a/crates/openshell-tui/src/lib.rs +++ b/crates/openshell-tui/src/lib.rs @@ -1960,6 +1960,7 @@ fn spawn_set_global_setting(app: &App, tx: mpsc::UnboundedSender) { setting_value: Some(SettingValue { value: Some(value) }), delete_setting: false, global: true, + merge_operations: vec![], }; let result = tokio::time::timeout(Duration::from_secs(5), client.update_config(req)).await; @@ -1994,6 +1995,7 @@ fn spawn_delete_global_setting(app: &App, tx: mpsc::UnboundedSender) { setting_value: None, delete_setting: true, global: true, + merge_operations: vec![], }; let result = tokio::time::timeout(Duration::from_secs(5), client.update_config(req)).await; @@ -2062,6 +2064,7 @@ fn spawn_set_sandbox_setting(app: &App, tx: mpsc::UnboundedSender) { setting_value: Some(SettingValue { value: Some(value) }), delete_setting: false, global: false, + merge_operations: vec![], }; let result = 
tokio::time::timeout(Duration::from_secs(5), client.update_config(req)).await; @@ -2100,6 +2103,7 @@ fn spawn_delete_sandbox_setting(app: &App, tx: mpsc::UnboundedSender) { setting_value: None, delete_setting: true, global: false, + merge_operations: vec![], }; let result = tokio::time::timeout(Duration::from_secs(5), client.update_config(req)).await; diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 3e505cf3e..7152731cf 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -30,7 +30,7 @@ network_policies: { ... } | `process` | object | No | Static | Sets the user and group the agent process runs as. | | `network_policies` | map | No | Dynamic | Declares which binaries can reach which network endpoints. | -Static fields are set at sandbox creation time. Changing them requires destroying and recreating the sandbox. Dynamic fields can be updated on a running sandbox with `openshell policy set` and take effect without restarting. +Static fields are set at sandbox creation time. Changing them requires destroying and recreating the sandbox. Dynamic fields can be updated on a running sandbox with `openshell policy update` for incremental merges or `openshell policy set` for full replacement, and take effect without restarting. ## Version diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index 8d4831f1b..781981453 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -121,21 +121,36 @@ The following steps outline the hot-reload policy update workflow. openshell logs --tail --source sandbox ``` -3. Pull the current policy. Strip the metadata header (Version, Hash, Status) before reusing the file. +3. For additive network changes, use `openshell policy update`. This is the fastest path for adding endpoints, binaries, or REST allow/deny rules without replacing the full policy. 
+ + ```shell + openshell policy update \ + --add-endpoint api.github.com:443:read-only:rest:enforce \ + --binary /usr/bin/gh \ + --wait + + openshell policy update \ + --add-allow api.github.com:443:POST:/repos/*/issues \ + --wait + ``` + + `--add-allow` and `--add-deny` currently target existing `protocol: rest` endpoints only. If you pass multiple update flags in one command, OpenShell applies them as one atomic merge batch and persists at most one new revision. + +4. For larger edits, pull the current policy and edit the YAML directly. Strip the metadata header (Version, Hash, Status) before reusing the file. ```shell openshell policy get --full > current-policy.yaml ``` -4. Edit the YAML: add or adjust `network_policies` entries, binaries, `access`, or `rules`. +5. Edit the YAML: add or adjust `network_policies` entries, binaries, `access`, or `rules`. -5. Push the updated policy. Exit codes: 0 = loaded, 1 = validation failed, 124 = timeout. +6. Push the updated policy when you need a full replacement. Exit codes: 0 = loaded, 1 = validation failed, 124 = timeout. ```shell openshell policy set --policy current-policy.yaml --wait ``` -6. Verify the new revision. If status is `loaded`, repeat from step 2 as needed; if `failed`, fix the policy and repeat from step 4. +7. Verify the new revision. If status is `loaded`, repeat from step 2 as needed; if `failed`, fix the policy and repeat from step 5. ```shell openshell policy list @@ -178,9 +193,15 @@ When triaging denied requests, check: Then push the updated policy as described above. +For small changes, prefer `openshell policy update` over rewriting the full YAML: + +```shell +openshell policy update --add-allow api.github.com:443:GET:/repos/** --wait +``` + ## Examples -Add these blocks to the `network_policies` section of your sandbox policy. Apply with `openshell policy set --policy --wait`. +Add these blocks to the `network_policies` section of your sandbox policy. 
Apply with `openshell policy update` for incremental additions or `openshell policy set --policy --wait` for full replacement. Use **Simple endpoint** for host-level allowlists and **Granular rules** for method/path control. diff --git a/docs/security/best-practices.mdx b/docs/security/best-practices.mdx index a84800e4c..c0e0026ce 100644 --- a/docs/security/best-practices.mdx +++ b/docs/security/best-practices.mdx @@ -24,11 +24,11 @@ If you use [NemoClaw](https://github.com/NVIDIA/NemoClaw) to run OpenClaw assist OpenShell applies security controls at two enforcement points. OpenShell locks static controls at sandbox creation and requires destroying and recreating the sandbox to change them. -You can update dynamic controls on a running sandbox with `openshell policy set`. +You can update dynamic controls on a running sandbox with `openshell policy update` or `openshell policy set`. | Layer | What it protects | Enforcement point | Changeable at runtime | | --- | --- | --- | --- | -| Network | Unauthorized outbound connections and data exfiltration. | CONNECT proxy + OPA policy engine | Yes. Use `openshell policy set` or operator approval in the TUI. | +| Network | Unauthorized outbound connections and data exfiltration. | CONNECT proxy + OPA policy engine | Yes. Use `openshell policy update`, `openshell policy set`, or operator approval in the TUI. | | Filesystem | System binary tampering, credential theft, config manipulation. | Landlock LSM (kernel level) | No. Requires sandbox re-creation. | | Process | Privilege escalation, fork bombs, dangerous syscalls. | Seccomp BPF + privilege drop (`setuid`/`setgid`) | No. Requires sandbox re-creation. | | Inference | Credential exposure, unauthorized model access. | Proxy intercept of `inference.local` | Yes. Use `openshell inference set`. | @@ -46,7 +46,7 @@ If no `network_policies` entry matches the destination host, port, and calling b | Aspect | Detail | |---|---| | Default | All egress denied. 
Only endpoints listed in `network_policies` can receive traffic. | -| What you can change | Add entries to `network_policies` in the policy YAML. Apply statically at creation (`--policy`) or dynamically (`openshell policy set`). | +| What you can change | Add entries to `network_policies` in the policy YAML. Apply statically at creation (`--policy`) or dynamically (`openshell policy update` for incremental changes, `openshell policy set` for full replacement). | | Risk if relaxed | Each allowed endpoint is a potential data exfiltration path. The agent can send workspace content, credentials, or conversation history to any reachable host. | | Recommendation | Add only endpoints the agent needs for its task. Start with a minimal policy and use denied-request logs (`openshell logs --source sandbox`) to identify missing endpoints. | diff --git a/proto/openshell.proto b/proto/openshell.proto index 0ee1e8904..b863e6251 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -574,6 +574,51 @@ message UpdateConfigRequest { bool delete_setting = 5; // Apply mutation at gateway-global scope. bool global = 6; + // Batched incremental policy merge operations. Sandbox-scoped only. 
+ repeated PolicyMergeOperation merge_operations = 7; +} + +message PolicyMergeOperation { + oneof operation { + AddNetworkRule add_rule = 1; + RemoveNetworkEndpoint remove_endpoint = 2; + RemoveNetworkRule remove_rule = 3; + AddDenyRules add_deny_rules = 4; + AddAllowRules add_allow_rules = 5; + RemoveNetworkBinary remove_binary = 6; + } +} + +message AddNetworkRule { + string rule_name = 1; + openshell.sandbox.v1.NetworkPolicyRule rule = 2; +} + +message RemoveNetworkEndpoint { + string rule_name = 1; + string host = 2; + uint32 port = 3; +} + +message RemoveNetworkRule { + string rule_name = 1; +} + +message AddDenyRules { + string host = 1; + uint32 port = 2; + repeated openshell.sandbox.v1.L7DenyRule deny_rules = 3; +} + +message AddAllowRules { + string host = 1; + uint32 port = 2; + repeated openshell.sandbox.v1.L7Rule rules = 3; +} + +message RemoveNetworkBinary { + string rule_name = 1; + string binary_path = 2; } // Update sandbox policy response.
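These proto operations are driven from the CLI's colon-separated `--add-endpoint` spec, `host:port[:access[:protocol[:enforcement]]]`. As a rough illustration of how that spec decomposes before it becomes an `AddNetworkRule`, here is a standalone sketch. The names (`EndpointSpec`, `parse_endpoint_spec`) and the assumed defaults (`read-only`, `rest`, `enforce` when trailing parts are omitted) are hypothetical — this is not the actual openshell-cli parser:

```rust
// Hypothetical sketch of splitting an --add-endpoint spec.
// Grammar (from the flag table): host:port[:access[:protocol[:enforcement]]]
#[derive(Debug, PartialEq)]
struct EndpointSpec {
    host: String,
    port: u16,
    access: String,      // read-only | read-write | full
    protocol: String,    // rest | sql
    enforcement: String, // enforce | audit
}

fn parse_endpoint_spec(spec: &str) -> Result<EndpointSpec, String> {
    let parts: Vec<&str> = spec.split(':').collect();
    if parts.len() < 2 || parts.len() > 5 {
        return Err(format!(
            "expected host:port[:access[:protocol[:enforcement]]], got {spec:?}"
        ));
    }
    let host = parts[0].trim();
    if host.is_empty() {
        return Err("host must be non-empty".to_string());
    }
    let port: u16 = parts[1]
        .parse()
        .map_err(|e| format!("invalid port {:?}: {e}", parts[1]))?;
    if port == 0 {
        return Err("port must be non-zero".to_string());
    }
    // Optional trailing fields fall back to assumed defaults and are
    // validated against the value sets listed in the flag table.
    let pick = |i: usize, default: &str, allowed: &[&str]| -> Result<String, String> {
        let value = parts.get(i).copied().unwrap_or(default);
        if allowed.contains(&value) {
            Ok(value.to_string())
        } else {
            Err(format!("unsupported value {value:?}, expected one of {allowed:?}"))
        }
    };
    Ok(EndpointSpec {
        host: host.to_string(),
        port,
        access: pick(2, "read-only", &["read-only", "read-write", "full"])?,
        protocol: pick(3, "rest", &["rest", "sql"])?,
        enforcement: pick(4, "enforce", &["enforce", "audit"])?,
    })
}

fn main() {
    let spec = parse_endpoint_spec("api.github.com:443:read-only:rest:enforce").unwrap();
    println!("{spec:?}");
}
```

Note that naive `:` splitting would mis-handle bracketed IPv6 hosts; a real parser would need to treat those specially, which this sketch omits.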