Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions proxy_agent/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ clap = { version = "4.5.17", features =["derive"] } # Command Line Argument Pars
thiserror = "1.0.64"
libc = "0.2.147"
socket2 = "0.5" # Set socket options without tokio/std conversion
base64 = "0.22"

[dependencies.uuid]
version = "1.3.0"
Expand Down
188 changes: 127 additions & 61 deletions proxy_agent/src/key_keeper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@
//! ```

pub mod key;
pub mod local_rules;

use self::key::Key;
use self::local_rules::{
get_rules_dir_from_key_dir, resolve_effective_rules, LocalRuleStateTracker, LocalRuleTarget,
};
use crate::common::error::{Error, KeyErrorType};
use crate::common::result::Result;
use crate::common::{constants, helpers, logger};
use crate::key_keeper::key::KeyStatus;
use crate::key_keeper::key::{AuthorizationRules, KeyStatus};
use crate::provision;
use crate::proxy::authorization_rules::{AuthorizationRulesForLogging, ComputedAuthorizationRules};
use crate::shared_state::access_control_wrapper::AccessControlSharedState;
Expand Down Expand Up @@ -70,6 +74,8 @@ pub struct KeyKeeper {
port: u16,
/// key_dir: the folder to save the key details
key_dir: PathBuf,
/// rules_dir: the folder to save customer-managed local access control rules
rules_dir: PathBuf,
/// status_dir: the folder to log the access control rule details
status_dir: PathBuf,
/// interval: the interval to poll the secure channel status
Expand Down Expand Up @@ -110,10 +116,12 @@ impl KeyKeeper {
interval: Duration,
shared_state: &SharedState,
) -> Self {
let rules_dir = get_rules_dir_from_key_dir(&key_dir);
KeyKeeper {
host,
port,
key_dir,
rules_dir,
status_dir,
interval,
cancellation_token: shared_state.get_cancellation_token(),
Expand All @@ -132,34 +140,8 @@ impl KeyKeeper {
self.update_status_message("poll secure channel status task started.".to_string(), true)
.await;

if let Err(e) = misc_helpers::try_create_folder(&self.key_dir) {
logger::write_warning(format!(
"key folder {} created failed with error {}.",
misc_helpers::path_to_string(&self.key_dir),
e
));
} else {
logger::write(format!(
"key folder {} created if not exists before.",
misc_helpers::path_to_string(&self.key_dir)
));
}

match acl::acl_directory(self.key_dir.clone()) {
Ok(()) => {
logger::write(format!(
"Folder {} ACLed if has not before.",
misc_helpers::path_to_string(&self.key_dir)
));
}
Err(e) => {
logger::write_warning(format!(
"Folder {} ACLed failed with error {}.",
misc_helpers::path_to_string(&self.key_dir),
e
));
}
}
self.ensure_secure_directory(&self.key_dir, "key");
self.ensure_secure_directory(&self.rules_dir, "rules");

// acl current executable dir
#[cfg(windows)]
Expand Down Expand Up @@ -194,6 +176,7 @@ impl KeyKeeper {
let mut first_iteration: bool = true;
let mut started_event_threads: bool = false;
let mut provision_timeout: bool = false;
let mut local_rule_state_tracker = LocalRuleStateTracker::default();
let notify = match self.key_keeper_shared_state.get_notify().await {
Ok(notify) => notify,
Err(e) => {
Expand Down Expand Up @@ -280,7 +263,8 @@ impl KeyKeeper {
previous_key_status_message = Some(key_status_message);
}

self.update_access_control_rules(&status).await;
self.update_access_control_rules(&status, &mut local_rule_state_tracker)
.await;

let state = status.get_secure_channel_state();
let secure_channel_state_updated = self
Expand Down Expand Up @@ -483,11 +467,18 @@ impl KeyKeeper {

/// Update access control rules from the key status
/// Returns true if any rules changed
async fn update_access_control_rules(&self, status: &KeyStatus) -> bool {
async fn update_access_control_rules(
&self,
status: &KeyStatus,
local_rule_state_tracker: &mut LocalRuleStateTracker,
) -> bool {
let mut access_control_rules_changed = false;
let wireserver_rule_id = status.get_wireserver_rule_id();
let imds_rule_id = status.get_imds_rule_id();
let hostga_rule_id = status.get_hostga_rule_id();
let mut wireserver_rule_id_changed = false;
let mut imds_rule_id_changed = false;
let mut hostga_rule_id_changed = false;

// Update wireserver rules
match self
Expand All @@ -496,18 +487,11 @@ impl KeyKeeper {
.await
{
Ok((updated, old_wire_server_rule_id)) => {
wireserver_rule_id_changed = updated;
if updated {
logger::write_warning(format!(
"Wireserver rule id changed from '{old_wire_server_rule_id}' to '{wireserver_rule_id}'."
));
if let Err(e) = self
.access_control_shared_state
.set_wireserver_rules(status.get_wireserver_rules())
.await
{
logger::write_error(format!("Failed to set wireserver rules: {e}"));
}
access_control_rules_changed = true;
}
}
Err(e) => {
Expand All @@ -522,18 +506,11 @@ impl KeyKeeper {
.await
{
Ok((updated, old_imds_rule_id)) => {
imds_rule_id_changed = updated;
if updated {
logger::write_warning(format!(
"IMDS rule id changed from '{old_imds_rule_id}' to '{imds_rule_id}'."
));
if let Err(e) = self
.access_control_shared_state
.set_imds_rules(status.get_imds_rules())
.await
{
logger::write_error(format!("Failed to set imds rules: {e}"));
}
access_control_rules_changed = true;
}
}
Err(e) => {
Expand All @@ -548,40 +525,94 @@ impl KeyKeeper {
.await
{
Ok((updated, old_hostga_rule_id)) => {
hostga_rule_id_changed = updated;
if updated {
logger::write_warning(format!(
"HostGA rule id changed from '{old_hostga_rule_id}' to '{hostga_rule_id}'."
));
if let Err(e) = self
.access_control_shared_state
.set_hostga_rules(status.get_hostga_rules())
.await
{
logger::write_error(format!("Failed to set HostGA rules: {e}"));
}
access_control_rules_changed = true;
}
}
Err(e) => {
logger::write_warning(format!("Failed to update HostGA rule id: {e}"));
}
}

let (wireserver_rules, wireserver_local_state_changed) = resolve_effective_rules(
&self.rules_dir,
status.get_wireserver_rules(),
LocalRuleTarget::WireServer,
&mut local_rule_state_tracker.wireserver,
wireserver_rule_id_changed,
)
.await;
let (imds_rules, imds_local_state_changed) = resolve_effective_rules(
&self.rules_dir,
status.get_imds_rules(),
LocalRuleTarget::Imds,
&mut local_rule_state_tracker.imds,
imds_rule_id_changed,
)
.await;

if wireserver_rule_id_changed || wireserver_local_state_changed {
if let Err(e) = self
.access_control_shared_state
.set_wireserver_rules(wireserver_rules.clone())
.await
{
logger::write_error(format!("Failed to set wireserver rules: {e}"));
}
access_control_rules_changed = true;
}

if imds_rule_id_changed || imds_local_state_changed {
if let Err(e) = self
.access_control_shared_state
.set_imds_rules(imds_rules.clone())
.await
{
logger::write_error(format!("Failed to set imds rules: {e}"));
}
access_control_rules_changed = true;
}

// HostGA rules only come from server and do not have local rules, so only update when rule id changed
let hostga_rules = status.get_hostga_rules();
if hostga_rule_id_changed {
if let Err(e) = self
Comment thread
ZhidongPeng marked this conversation as resolved.
.access_control_shared_state
.set_hostga_rules(hostga_rules.clone())
.await
{
logger::write_error(format!("Failed to set HostGA rules: {e}"));
}
access_control_rules_changed = true;
}

// Write authorization rules to file if changed
if access_control_rules_changed {
if let (Ok(wireserver_rules), Ok(imds_rules), Ok(hostga_rules)) = (
let effective_rules = Some(AuthorizationRules {
wireserver: wireserver_rules.clone(),
imds: imds_rules.clone(),
hostga: hostga_rules.clone(),
});
if let (
Ok(computed_wireserver_rules),
Ok(computed_imds_rules),
Ok(computed_hostga_rules),
) = (
self.access_control_shared_state
.get_wireserver_rules()
.await,
self.access_control_shared_state.get_imds_rules().await,
self.access_control_shared_state.get_hostga_rules().await,
) {
let rules = AuthorizationRulesForLogging::new(
status.authorizationRules.clone(),
effective_rules,
ComputedAuthorizationRules {
wireserver: wireserver_rules,
imds: imds_rules,
hostga: hostga_rules,
wireserver: computed_wireserver_rules,
imds: computed_imds_rules,
hostga: computed_hostga_rules,
},
);
rules.write_all(&self.status_dir, constants::MAX_LOG_FILE_COUNT);
Expand All @@ -591,6 +622,39 @@ impl KeyKeeper {
access_control_rules_changed
}

/// Ensure the directory exists and has secure ACLs.
/// If the directory does not exist, it will be created.
fn ensure_secure_directory(&self, dir: &Path, dir_kind: &str) {
if let Err(e) = misc_helpers::try_create_folder(dir) {
logger::write_warning(format!(
"{dir_kind} folder {} created failed with error {}.",
misc_helpers::path_to_string(dir),
e
));
} else {
logger::write(format!(
"{dir_kind} folder {} created if not exists before.",
misc_helpers::path_to_string(dir)
));
}

match acl::acl_directory(dir.to_path_buf()) {
Ok(()) => {
logger::write(format!(
"Folder {} ACLed if has not before.",
misc_helpers::path_to_string(dir)
));
}
Err(e) => {
logger::write_warning(format!(
"Folder {} ACLed failed with error {}.",
misc_helpers::path_to_string(dir),
e
));
}
}
}

/// Handle key acquisition from local or server
/// Returns true if successful, false if should continue to next iteration
async fn handle_key_acquisition(&self, status: &KeyStatus, state: &str) -> bool {
Expand Down Expand Up @@ -1001,8 +1065,9 @@ impl KeyKeeper {
#[cfg(test)]
mod tests {
use super::key::Key;
use super::KeyKeeper;
use super::local_rules;
use crate::key_keeper;
use crate::key_keeper::KeyKeeper;
use proxy_agent_shared::misc_helpers;
use proxy_agent_shared::server_mock;
use std::env;
Expand Down Expand Up @@ -1084,6 +1149,7 @@ mod tests {
host: ip.to_string(),
port,
key_dir: cloned_keys_dir.clone(),
rules_dir: local_rules::get_rules_dir_from_key_dir(&cloned_keys_dir),
status_dir: cloned_keys_dir.clone(),
interval: Duration::from_millis(10),
cancellation_token: cancellation_token.clone(),
Expand Down
Loading
Loading