Skip to content

Commit

Permalink
implement caching for all identity implementations (#1484)
Browse files Browse the repository at this point in the history
  • Loading branch information
demoray committed Dec 1, 2023
1 parent 3dd2391 commit 0b3e2d1
Show file tree
Hide file tree
Showing 15 changed files with 324 additions and 158 deletions.
3 changes: 3 additions & 0 deletions sdk/core/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,7 @@ impl AccessToken {
pub trait TokenCredential: Send + Sync + Debug {
/// Gets a `AccessToken` for the specified resource
async fn get_token(&self, resource: &str) -> crate::Result<AccessToken>;

/// Clear the credential's cache.
async fn clear_cache(&self) -> crate::Result<()>;
}
4 changes: 4 additions & 0 deletions sdk/identity/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# UNRELEASED

- Removed AutoRefreshingTokenCredential, instead all token credentials now implement caching

# 0.3.0 (2022-05)

- [#756](https://github.com/Azure/azure-sdk-for-rust/pull/756) Export credentials from azure_identity
Expand Down
35 changes: 30 additions & 5 deletions sdk/identity/src/token_credentials/azure_cli_credentials.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::token_credentials::cache::TokenCache;
use azure_core::auth::{AccessToken, Secret, TokenCredential};
use azure_core::error::{Error, ErrorKind, ResultExt};
use serde::Deserialize;
Expand Down Expand Up @@ -104,15 +105,23 @@ struct CliTokenResponse {
}

/// Enables authentication to Azure Active Directory using Azure CLI to obtain an access token.
#[derive(Debug, Default)]
#[derive(Debug)]
pub struct AzureCliCredential {
_private: (),
cache: TokenCache,
}

impl Default for AzureCliCredential {
fn default() -> Self {
Self::new()
}
}

impl AzureCliCredential {
/// Create a new `AzureCliCredential`
pub fn new() -> Self {
Self::default()
Self {
cache: TokenCache::new(),
}
}

/// Get an access token for an optional resource
Expand All @@ -138,6 +147,11 @@ impl AzureCliCredential {
args.push(resource);
}

log::trace!(
"fetching credential via Azure CLI: {program} {}",
args.join(" "),
);

match Command::new(program).args(args).output() {
Ok(az_output) if az_output.status.success() => {
let output = str::from_utf8(&az_output.stdout)?;
Expand Down Expand Up @@ -174,14 +188,25 @@ impl AzureCliCredential {
let tr = Self::get_access_token(None)?;
Ok(tr.tenant)
}

async fn get_token(&self, resource: &str) -> azure_core::Result<AccessToken> {
let tr = Self::get_access_token(Some(resource))?;
Ok(AccessToken::new(tr.access_token, tr.expires_on))
}
}

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl TokenCredential for AzureCliCredential {
async fn get_token(&self, resource: &str) -> azure_core::Result<AccessToken> {
let tr = Self::get_access_token(Some(resource))?;
Ok(AccessToken::new(tr.access_token, tr.expires_on))
self.cache
.get_token(resource, self.get_token(resource))
.await
}

/// Clear the credential's cache.
async fn clear_cache(&self) -> azure_core::Result<()> {
self.cache.clear().await
}
}

Expand Down
103 changes: 100 additions & 3 deletions sdk/identity/src/token_credentials/azureauth_cli_credentials.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub struct AzureauthCliCredential {
client_id: ClientId,
modes: Vec<AzureauthCliMode>,
prompt_hint: Option<String>,
cache: TokenCache,
}

impl AzureauthCliCredential {
Expand All @@ -75,6 +76,7 @@ impl AzureauthCliCredential {
client_id: client_id.into(),
modes: Vec::new(),
prompt_hint: None,
cache: TokenCache,
}
}

Expand Down Expand Up @@ -180,16 +182,111 @@ impl AzureauthCliCredential {

let token_response = serde_json::from_str::<CliTokenResponse>(&output)
.map_kind(ErrorKind::DataConversion)?;
Ok(token_response)
Ok(TokenResponse::new(
token_response.token,
token_response.expiration_date,
))
}

/// Clear the azureauth cache as well as the internal cache
fn clear_cache(&self) -> azure_core::Result<CliTokenResponse> {
let resources = { self.cache.read().await.keys().cloned().collect::<Vec<_>>() };

// try using azureauth.exe first, such that azureauth through WSL is
// used first if possible.
let (cmd_name) = if Command::new("azureauth.exe")
.arg("--version")
.output()
.map(|x| x.status.success())
.unwrap_or(false)
{
"azureauth.exe"
} else {
"azureauth"
};

for resource in resources {
let mut resource = resource.to_owned();
if !resource.ends_with("/.default") {
if resource.ends_with('/') {
resource.push_str(".default");
} else {
resource.push_str("/.default");
}
}

let mut cmd = Command::new(cmd_name);
cmd.args([
"aad",
"--scope",
&resource,
"--client",
self.client_id.as_str(),
"--tenant",
self.tenant_id.as_str(),
"--clear",
]);

if let Some(prompt_hint) = &self.prompt_hint {
cmd.args(["--prompt-hint", prompt_hint]);
}

for mode in &self.modes {
match mode {
AzureauthCliMode::All => {
cmd.args(["--mode", "all"]);
}
AzureauthCliMode::IntegratedWindowsAuth => {
if use_windows_features {
cmd.args(["--mode", "iwa"]);
}
}
AzureauthCliMode::Broker => {
if use_windows_features {
cmd.args(["--mode", "broker"]);
}
}
AzureauthCliMode::Web => {
cmd.args(["--mode", "web"]);
}
};
}

let result = cmd.output();

let output = result.map_err(|e| match e.kind() {
std::io::ErrorKind::NotFound => {
Error::message(ErrorKind::Other, "azureauth CLI not installed")
}
error_kind => Error::with_message(ErrorKind::Other, || {
format!("Unknown error of kind: {error_kind:?}")
}),
})?;

if !output.status.success() {
let output = String::from_utf8_lossy(&output.stderr);
return Err(Error::with_message(ErrorKind::Credential, || {
format!("'azureauth' command failed: {output}")
}));
}
}

self.cache.clear().await?;

Ok(())
}
}

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl TokenCredential for AzureauthCliCredential {
async fn get_token(&self, resource: &str) -> azure_core::Result<TokenResponse> {
let tr = self.get_access_token(resource)?;
Ok(TokenResponse::new(tr.token, tr.expiration_date))
self.cache
.get_token(resource, self.get_access_token(resource))
.await
}
async fn clear_cache(&self) -> azure_core::Result<TokenResponse> {
self.cache.clear().await
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,66 +1,56 @@
use async_lock::RwLock;
use azure_core::auth::{AccessToken, TokenCredential};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use azure_core::auth::AccessToken;
use futures::Future;
use std::{collections::HashMap, time::Duration};
use time::OffsetDateTime;

fn is_expired(token: &AccessToken) -> bool {
token.expires_on < OffsetDateTime::now_utc() + Duration::from_secs(20)
}

#[derive(Clone)]
/// Wraps a `TokenCredential` and handles token refresh on token expiry.
pub struct AutoRefreshingTokenCredential {
credential: Arc<dyn TokenCredential>,
// Tokens are specific to a resource, so we cache tokens by resource.
token_cache: Arc<RwLock<HashMap<String, AccessToken>>>,
}
#[derive(Debug)]
pub(crate) struct TokenCache(RwLock<HashMap<String, AccessToken>>);

impl std::fmt::Debug for AutoRefreshingTokenCredential {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("AutoRefreshingTokenCredential")
.field("credential", &self.credential)
.field("token_cache", &"<REDACTED>")
.finish()
impl TokenCache {
pub(crate) fn new() -> Self {
Self(RwLock::new(HashMap::new()))
}
}

impl AutoRefreshingTokenCredential {
/// Create a new `AutoRefreshingTokenCredential` around the provided base provider.
pub fn new(provider: Arc<dyn TokenCredential>) -> Self {
Self {
credential: provider,
token_cache: Arc::new(RwLock::new(HashMap::new())),
}
pub(crate) async fn clear(&self) -> azure_core::Result<()> {
let mut token_cache = self.0.write().await;
token_cache.clear();
Ok(())
}
}

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl TokenCredential for AutoRefreshingTokenCredential {
async fn get_token(&self, resource: &str) -> azure_core::Result<AccessToken> {
pub(crate) async fn get_token(
&self,
resource: &str,
callback: impl Future<Output = azure_core::Result<AccessToken>>,
) -> azure_core::Result<AccessToken> {
// if the current cached token for this resource is good, return it.
let token_cache = self.token_cache.read().await;
let token_cache = self.0.read().await;
if let Some(token) = token_cache.get(resource) {
if !is_expired(token) {
log::trace!("returning cached token");
return Ok(token.clone());
}
}

// otherwise, drop the read lock and get a write lock to refresh the token
drop(token_cache);
let mut token_cache = self.token_cache.write().await;
let mut token_cache = self.0.write().await;

// check again in case another thread refreshed the token while we were
// waiting on the write lock
if let Some(token) = token_cache.get(resource) {
if !is_expired(token) {
log::trace!("returning token that was updated while waiting on write lock");
return Ok(token.clone());
}
}

let token = self.credential.get_token(resource).await?;
log::trace!("falling back to callback");
let token = callback.await?;

// NOTE: we do not check to see if the token is expired here, as at
// least one credential, `AzureCliCredential`, specifies the token is
Expand All @@ -75,7 +65,6 @@ impl TokenCredential for AutoRefreshingTokenCredential {
mod tests {
use super::*;
use azure_core::auth::Secret;
use azure_core::auth::TokenCredential;
use std::sync::Mutex;

#[derive(Debug)]
Expand All @@ -91,11 +80,7 @@ mod tests {
get_token_call_count: Mutex::new(0),
}
}
}

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl TokenCredential for MockCredential {
async fn get_token(&self, resource: &str) -> azure_core::Result<AccessToken> {
// Include an incrementing counter in the token to track how many times the token has been refreshed
let mut call_count = self.get_token_call_count.lock().unwrap();
Expand Down Expand Up @@ -124,20 +109,29 @@ mod tests {
let access_token = AccessToken::new(Secret::new(secret_string), expires_on);

let mock_credential = MockCredential::new(access_token);
let auto_refreshing_credential =
AutoRefreshingTokenCredential::new(Arc::new(mock_credential));

let cache = TokenCache::new();

// Test that querying a token for the same resource twice returns the same (cached) token on the second call
let token1 = auto_refreshing_credential.get_token(resource1).await?;
let token2 = auto_refreshing_credential.get_token(resource1).await?;
let token1 = cache
.get_token(resource1, mock_credential.get_token(resource1))
.await?;
let token2 = cache
.get_token(resource1, mock_credential.get_token(resource1))
.await?;

let expected_token = format!("{}-{}:1", resource1, secret_string);
assert_eq!(token1.token.secret(), expected_token);
assert_eq!(token2.token.secret(), expected_token);

// Test that querying a token for a second resource returns a different token, as the cache is per-resource.
// Also test that the same token is the returned (cached) on a second call.
let token3 = auto_refreshing_credential.get_token(resource2).await?;
let token4 = auto_refreshing_credential.get_token(resource2).await?;
let token3 = cache
.get_token(resource2, mock_credential.get_token(resource2))
.await?;
let token4 = cache
.get_token(resource2, mock_credential.get_token(resource2))
.await?;
let expected_token = format!("{}-{}:2", resource2, secret_string);
assert_eq!(token3.token.secret(), expected_token);
assert_eq!(token4.token.secret(), expected_token);
Expand All @@ -153,12 +147,14 @@ mod tests {
let token_response = AccessToken::new(Secret::new(access_token), expires_on);

let mock_credential = MockCredential::new(token_response);
let auto_refreshing_credential =
AutoRefreshingTokenCredential::new(Arc::new(mock_credential));

let cache = TokenCache::new();

// Test that querying an expired token returns a new token
for i in 1..5 {
let token = auto_refreshing_credential.get_token(resource).await?;
let token = cache
.get_token(resource, mock_credential.get_token(resource))
.await?;
assert_eq!(
token.token.secret(),
format!("{}-{}:{}", resource, access_token, i)
Expand Down
Loading

0 comments on commit 0b3e2d1

Please sign in to comment.