Skip to content

Commit

Permalink
Use Azure SDK for HTTP communication with IMDS
Browse files Browse the repository at this point in the history
Use azure_svc_imds, azure_core, azure_identity of Azure SDK for Rust,
to make use of RetryOptions::exponential for HTTP clients.
Doing that, it is possible for azure-init to retry to send HTTP
requests, when requests failed for some reason.

For now we define a fixed const for timeout, initial delay, max retries,
for simplicity.
  • Loading branch information
dongsupark committed Jul 19, 2024
1 parent 3ac3f41 commit 7b58355
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 26 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ anyhow = "1.0.81"
tokio = { version = "1", features = ["full"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
tracing = "0.1.40"
# Azure SDK crates must be 0.19 or older, because Rust 1.71.1 needed for backwards
# compatibility cannot be used for Azure SDK 0.20, which requires Rust 1.74 or newer.
azure_core = "0.19.0"
azure_identity = "0.19.0"
azure_svc_imds = "0.19.0"

[dependencies.libazureinit]
path = "libazureinit"
Expand Down
5 changes: 5 additions & 0 deletions libazureinit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ block-utils = "0.11.1"
tracing = "0.1.40"
strum = { version = "0.26.3", features = ["derive"] }
fstab = "0.4.0"
http = "1.1.0"
# Azure SDK crates must be 0.19 or older, because Rust 1.71.1 needed for backwards
# compatibility cannot be used for Azure SDK 0.20, which requires Rust 1.74 or newer.
azure_svc_imds = "0.19.0"
azure_core = "0.19.0"

[dev-dependencies]
tempfile = "3"
Expand Down
6 changes: 6 additions & 0 deletions libazureinit/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ pub enum Error {
Xml(#[from] serde_xml_rs::Error),
#[error("HTTP client error ocurred")]
Http(#[from] reqwest::Error),
#[error("HTTP invalid status code returned")]
InvalidStatusCode(#[from] http::status::InvalidStatusCode),
#[error("An I/O error occurred")]
Io(#[from] std::io::Error),
#[error("HTTP request did not succeed (HTTP {status} from {endpoint})")]
Expand Down Expand Up @@ -47,4 +49,8 @@ pub enum Error {
"Failed to set the user password; none of the provided backends succeeded"
)]
NoPasswordProvisioner,
#[error("Azure core error ocurred")]
AzureCoreError(#[from] azure_core::error::Error),
#[error("Azure service IMDS error ocurred")]
AzureSvcImdsError,
}
85 changes: 61 additions & 24 deletions libazureinit/src/imds.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use reqwest;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use reqwest::Client;

use azure_svc_imds::models::PublicKeysProperties;
use azure_svc_imds::Client as AzureClient;
use http::StatusCode;
use serde::{Deserialize, Deserializer};
use serde_json;
use serde_json::Value;
Expand Down Expand Up @@ -46,6 +44,21 @@ pub struct OsProfile {
pub disable_password_authentication: bool,
}

impl From<azure_svc_imds::models::OsProfile> for OsProfile {
fn from(osp: azure_svc_imds::models::OsProfile) -> Self {
Self {
admin_username: osp.admin_username.unwrap_or("".to_string()),
computer_name: osp.computer_name.unwrap_or("".to_string()),
disable_password_authentication: osp
.disable_password_authentication
.unwrap_or("".to_string())
.trim()
.parse()
.unwrap_or(true),
}
}
}

/// An SSH public key.
#[derive(Debug, Deserialize, PartialEq, Clone)]
pub struct PublicKeys {
Expand All @@ -66,6 +79,15 @@ impl From<&str> for PublicKeys {
}
}

impl From<PublicKeysProperties> for PublicKeys {
fn from(public_keys: PublicKeysProperties) -> Self {
Self {
key_data: public_keys.key_data.unwrap_or_default(),
path: public_keys.path.unwrap_or_default(),
}
}
}

/// Deserializer that handles the string "true" and "false" that the IMDS API returns.
fn string_bool<'de, D>(deserializer: D) -> Result<bool, D::Error>
where
Expand All @@ -87,26 +109,41 @@ where
}
}

pub async fn query(client: &Client) -> Result<InstanceMetadata, Error> {
let url = "http://169.254.169.254/metadata/instance?api-version=2021-02-01";
let mut headers = HeaderMap::new();

headers.insert("Metadata", HeaderValue::from_static("true"));

let request = client.get(url).headers(headers);
let response = request.send().await?;

if response.status().is_success() {
let imds_body = response.text().await?;
let metadata: InstanceMetadata = serde_json::from_str(&imds_body)?;

Ok(metadata)
} else {
Err(Error::HttpStatus {
endpoint: url.to_owned(),
status: response.status(),
})
pub async fn query(
client: &AzureClient,
endpoint: &azure_core::Url,
) -> Result<InstanceMetadata, Error> {
let resp = client
.instances_client()
.get_metadata("true")
.send()
.await?;

let response = resp.as_raw_response();

if !response.status().is_success() {
return Err(Error::HttpStatus {
endpoint: endpoint.to_string(),
status: StatusCode::from_u16(response.status().into())?,
});
}

let body = resp.into_body().await?;
let compute = body.compute.ok_or(Error::AzureSvcImdsError)?;

Ok(InstanceMetadata {
compute: Compute {
os_profile: compute.os_profile.unwrap_or_default().into(),
public_keys: compute
.public_keys
.iter()
.map(|c| PublicKeys {
key_data: c.key_data.clone().unwrap_or("".to_string()),
path: c.path.clone().unwrap_or("".to_string()),
})
.collect(),
},
})
}

#[cfg(test)]
Expand Down
29 changes: 27 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@
// Licensed under the MIT License.

use std::process::ExitCode;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Context;

use azure_core::{ExponentialRetryOptions, RetryOptions};
use azure_identity::AzureCliCredential;
use azure_svc_imds::Client as ImdsClient;

use libazureinit::imds::InstanceMetadata;
use libazureinit::User;
use libazureinit::{
Expand All @@ -19,6 +26,12 @@ use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::EnvFilter;

const VERSION: &str = env!("CARGO_PKG_VERSION");
const HTTP_TIMEOUT_SEC: u64 = 30;
const HTTP_INIT_DELAY_SEC: u64 = 5;
const HTTP_MAX_ELAPSED_SEC: u64 = 150;
const HTTP_MAX_RETRIES: u32 = 5;
const IMDS_ENDPOINT: &str =
"http://169.254.169.254/metadata/instance?api-version=2021-02-01";

#[instrument]
fn get_environment() -> Result<Environment, anyhow::Error> {
Expand Down Expand Up @@ -102,16 +115,28 @@ async fn provision() -> Result<(), anyhow::Error> {
)?;
default_headers.insert(header::USER_AGENT, user_agent);
let client = Client::builder()
.timeout(std::time::Duration::from_secs(30))
.timeout(std::time::Duration::from_secs(HTTP_TIMEOUT_SEC))
.default_headers(default_headers)
.build()?;

let imds_endpoint = azure_core::Url::parse(IMDS_ENDPOINT)?;
let imds_client = ImdsClient::builder(Arc::new(AzureCliCredential::new()))
.endpoint(imds_endpoint.clone())
.retry(RetryOptions::exponential(ExponentialRetryOptions {
initial_delay: Duration::from_secs(HTTP_INIT_DELAY_SEC),
max_retries: HTTP_MAX_RETRIES,
max_total_elapsed: Duration::from_secs(HTTP_MAX_ELAPSED_SEC),
max_delay: Duration::from_secs(HTTP_TIMEOUT_SEC),
}))
.build()?;

// Username can be obtained either via fetching instance metadata from IMDS
// or mounting a local device for OVF environment file. It should not fail
// immediately in a single failure, instead it should fall back to the other
// mechanism. So it is not a good idea to use `?` for query() or
// get_environment().
let instance_metadata = imds::query(&client).await.ok();
let instance_metadata =
imds::query(&imds_client, &imds_endpoint).await.ok();

let environment = get_environment().ok();

Expand Down

0 comments on commit 7b58355

Please sign in to comment.