From 41d96b21628a9035623b72e80d7ea06e5fb904f9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 29 Jul 2022 13:27:37 -0400 Subject: [PATCH] Add Builder style config objects for object_store (#2204) * Add AmazonS3Config, MicrosoftAzureBuilder, GoogleCloudStorageBuilder * fix: improve docs * review feedback: remove old code, make with_client test only --- object_store/src/aws.rs | 370 ++++++++++++++++++++++++-------------- object_store/src/azure.rs | 220 ++++++++++++++--------- object_store/src/gcp.rs | 230 +++++++++++++++--------- 3 files changed, 516 insertions(+), 304 deletions(-) diff --git a/object_store/src/aws.rs b/object_store/src/aws.rs index 3606a3806f9..89a2185128b 100644 --- a/object_store/src/aws.rs +++ b/object_store/src/aws.rs @@ -228,6 +228,14 @@ enum Error { source: rusoto_core::region::ParseRegionError, }, + #[snafu(display( + "Region must be specified for AWS S3. Regions should look like `us-east-2`" + ))] + MissingRegion {}, + + #[snafu(display("Missing bucket name"))] + MissingBucketName {}, + #[snafu(display("Missing aws-access-key"))] MissingAccessKey, @@ -584,99 +592,195 @@ fn convert_object_meta(object: rusoto_s3::Object, bucket: &str) -> Result>, - secret_access_key: Option>, - region: impl Into, - bucket_name: impl Into, - endpoint: Option>, - session_token: Option>, +/// +/// # Example +/// ``` +/// # let REGION = "foo"; +/// # let BUCKET_NAME = "foo"; +/// # let ACCESS_KEY_ID = "foo"; +/// # let SECRET_KEY = "foo"; +/// let s3 = object_store::aws::AmazonS3Builder::new() +/// .with_region(REGION) +/// .with_bucket_name(BUCKET_NAME) +/// .with_access_key_id(ACCESS_KEY_ID) +/// .with_secret_access_key(SECRET_KEY) +/// .build(); +/// ``` +#[derive(Debug)] +pub struct AmazonS3Builder { + access_key_id: Option, + secret_access_key: Option, + region: Option, + bucket_name: Option, + endpoint: Option, + token: Option, max_connections: NonZeroUsize, allow_http: bool, -) -> Result { - let region = region.into(); - let region: rusoto_core::Region = match endpoint { - None => region.parse().context(InvalidRegionSnafu { region })?, - Some(endpoint) => rusoto_core::Region::Custom { - name: region, - endpoint: endpoint.into(), - }, - }; +} - let mut builder = HyperBuilder::default(); - builder.pool_max_idle_per_host(max_connections.get()); - - let connector = if allow_http { - hyper_rustls::HttpsConnectorBuilder::new() - .with_webpki_roots() - .https_or_http() - .enable_http1() - .enable_http2() - .build() - } else { - hyper_rustls::HttpsConnectorBuilder::new() - .with_webpki_roots() - .https_only() - .enable_http1() - .enable_http2() - .build() - }; +impl Default for AmazonS3Builder { + fn default() -> Self { + Self { + access_key_id: None, + secret_access_key: None, + region: None, + bucket_name: None, + endpoint: None, + token: None, + max_connections: NonZeroUsize::new(16).unwrap(), + allow_http: false, + } + } +} - let http_client = rusoto_core::request::HttpClient::from_builder(builder, connector); +impl AmazonS3Builder { + /// Create a new [`AmazonS3Builder`] with default values. 
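+    ///
+    /// All fields start unset and are filled in with the `with_*` methods
+    /// before calling `build`. As an illustrative sketch, a configuration for
+    /// testing against a local S3-compatible endpoint (every value below is a
+    /// placeholder) might look like:
+    ///
+    /// ```
+    /// let s3 = object_store::aws::AmazonS3Builder::new()
+    ///     .with_region("us-east-1")
+    ///     .with_bucket_name("test-bucket")
+    ///     .with_access_key_id("test")
+    ///     .with_secret_access_key("test")
+    ///     .with_endpoint("http://localhost:4566")
+    ///     .with_allow_http(true)
+    ///     .build();
+    /// ```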
+    pub fn new() -> Self {
+        Default::default()
+    }
-    let client = match (access_key_id, secret_access_key, session_token) {
-        (Some(access_key_id), Some(secret_access_key), Some(session_token)) => {
-            let credentials_provider = StaticProvider::new(
-                access_key_id.into(),
-                secret_access_key.into(),
-                Some(session_token.into()),
-                None,
-            );
-            rusoto_s3::S3Client::new_with(http_client, credentials_provider, region)
-        }
-        (Some(access_key_id), Some(secret_access_key), None) => {
-            let credentials_provider = StaticProvider::new_minimal(
-                access_key_id.into(),
-                secret_access_key.into(),
-            );
-            rusoto_s3::S3Client::new_with(http_client, credentials_provider, region)
-        }
-        (None, Some(_), _) => return Err(Error::MissingAccessKey.into()),
-        (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()),
-        _ if std::env::var_os("AWS_WEB_IDENTITY_TOKEN_FILE").is_some() => {
-            rusoto_s3::S3Client::new_with(
-                http_client,
-                WebIdentityProvider::from_k8s_env(),
-                region,
-            )
-        }
-        _ => rusoto_s3::S3Client::new_with(
-            http_client,
-            InstanceMetadataProvider::new(),
+    /// Set the AWS Access Key (required)
+    pub fn with_access_key_id(mut self, access_key_id: impl Into<String>) -> Self {
+        self.access_key_id = Some(access_key_id.into());
+        self
+    }
+
+    /// Set the AWS Secret Access Key (required)
+    pub fn with_secret_access_key(
+        mut self,
+        secret_access_key: impl Into<String>,
+    ) -> Self {
+        self.secret_access_key = Some(secret_access_key.into());
+        self
+    }
+
+    /// Set the region (e.g. `us-east-1`) (required)
+    pub fn with_region(mut self, region: impl Into<String>) -> Self {
+        self.region = Some(region.into());
+        self
+    }
+
+    /// Set the bucket_name (required)
+    pub fn with_bucket_name(mut self, bucket_name: impl Into<String>) -> Self {
+        self.bucket_name = Some(bucket_name.into());
+        self
+    }
+
+    /// Sets the endpoint for communicating with AWS S3. Default value
+    /// is based on region.
+    ///
+    /// For example, this might be set to `"http://localhost:4566"`
+    /// for testing against a localstack instance.
+    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
+        self.endpoint = Some(endpoint.into());
+        self
+    }
+
+    /// Set the token to use for requests (passed to underlying provider)
+    pub fn with_token(mut self, token: impl Into<String>) -> Self {
+        self.token = Some(token.into());
+        self
+    }
+
+    /// Sets the maximum number of concurrent outstanding
+    /// connections. Default is `16`.
+    pub fn with_max_connections(mut self, max_connections: NonZeroUsize) -> Self {
+        self.max_connections = max_connections;
+        self
+    }
+
+    /// Sets which protocols are allowed. If `allow_http` is:
+    /// * false (default): only HTTPS requests are allowed
+    /// * true: both HTTP and HTTPS requests are allowed
+    pub fn with_allow_http(mut self, allow_http: bool) -> Self {
+        self.allow_http = allow_http;
+        self
+    }
+
+    /// Create an [`AmazonS3`] instance from the provided values,
+    /// consuming `self`.
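+    ///
+    /// If no access key and secret key are provided, `build` falls back to
+    /// ambient credentials: a web identity token when
+    /// `AWS_WEB_IDENTITY_TOKEN_FILE` is set, otherwise the EC2 instance
+    /// metadata provider. A minimal sketch (the region and bucket below are
+    /// placeholders):
+    ///
+    /// ```
+    /// let s3 = object_store::aws::AmazonS3Builder::new()
+    ///     .with_region("us-east-1")
+    ///     .with_bucket_name("test-bucket")
+    ///     .build();
+    /// ```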
+ pub fn build(self) -> Result { + let Self { + access_key_id, + secret_access_key, region, - ), - }; + bucket_name, + endpoint, + token, + max_connections, + allow_http, + } = self; + + let region = region.ok_or(Error::MissingRegion {})?; + let bucket_name = bucket_name.ok_or(Error::MissingBucketName {})?; + + let region: rusoto_core::Region = match endpoint { + None => region.parse().context(InvalidRegionSnafu { region })?, + Some(endpoint) => rusoto_core::Region::Custom { + name: region, + endpoint, + }, + }; - Ok(AmazonS3 { - client_unrestricted: client, - connection_semaphore: Arc::new(Semaphore::new(max_connections.get())), - bucket_name: bucket_name.into(), - }) -} + let mut builder = HyperBuilder::default(); + builder.pool_max_idle_per_host(max_connections.get()); + + let connector = if allow_http { + hyper_rustls::HttpsConnectorBuilder::new() + .with_webpki_roots() + .https_or_http() + .enable_http1() + .enable_http2() + .build() + } else { + hyper_rustls::HttpsConnectorBuilder::new() + .with_webpki_roots() + .https_only() + .enable_http1() + .enable_http2() + .build() + }; + + let http_client = + rusoto_core::request::HttpClient::from_builder(builder, connector); + + let client = match (access_key_id, secret_access_key, token) { + (Some(access_key_id), Some(secret_access_key), Some(token)) => { + let credentials_provider = StaticProvider::new( + access_key_id, + secret_access_key, + Some(token), + None, + ); + rusoto_s3::S3Client::new_with(http_client, credentials_provider, region) + } + (Some(access_key_id), Some(secret_access_key), None) => { + let credentials_provider = + StaticProvider::new_minimal(access_key_id, secret_access_key); + rusoto_s3::S3Client::new_with(http_client, credentials_provider, region) + } + (None, Some(_), _) => return Err(Error::MissingAccessKey.into()), + (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()), + _ if std::env::var_os("AWS_WEB_IDENTITY_TOKEN_FILE").is_some() => { + rusoto_s3::S3Client::new_with( + http_client, + WebIdentityProvider::from_k8s_env(), + region, + ) + } + _ => rusoto_s3::S3Client::new_with( + http_client, + InstanceMetadataProvider::new(), + region, + ), + }; -/// Create a new [`AmazonS3`] that always errors -pub fn new_failing_s3() -> Result { - new_s3( - Some("foo"), - Some("bar"), - "us-east-1", - "bucket", - None as Option<&str>, - None as Option<&str>, - NonZeroUsize::new(16).unwrap(), - true, - ) + Ok(AmazonS3 { + client_unrestricted: client, + connection_semaphore: Arc::new(Semaphore::new(max_connections.get())), + bucket_name, + }) + } } /// S3 client bundled w/ a semaphore permit. @@ -1057,7 +1161,7 @@ mod tests { get_nonexistent_object, list_uses_directories_correctly, list_with_delimiter, put_get_delete_list, rename_and_copy, stream_get, }, - Error as ObjectStoreError, ObjectStore, + Error as ObjectStoreError, }; use bytes::Bytes; use std::env; @@ -1067,17 +1171,9 @@ mod tests { const NON_EXISTENT_NAME: &str = "nonexistentname"; - #[derive(Debug)] - struct AwsConfig { - access_key_id: String, - secret_access_key: String, - region: String, - bucket: String, - endpoint: Option, - token: Option, - } - - // Helper macro to skip tests if TEST_INTEGRATION and the AWS environment variables are not set. + // Helper macro to skip tests if TEST_INTEGRATION and the AWS + // environment variables are not set. Returns a configured + // AmazonS3Builder macro_rules! 
maybe_skip_integration { () => {{ dotenv::dotenv().ok(); @@ -1116,18 +1212,38 @@ mod tests { ); return; } else { - AwsConfig { - access_key_id: env::var("AWS_ACCESS_KEY_ID") - .expect("already checked AWS_ACCESS_KEY_ID"), - secret_access_key: env::var("AWS_SECRET_ACCESS_KEY") - .expect("already checked AWS_SECRET_ACCESS_KEY"), - region: env::var("AWS_DEFAULT_REGION") - .expect("already checked AWS_DEFAULT_REGION"), - bucket: env::var("OBJECT_STORE_BUCKET") - .expect("already checked OBJECT_STORE_BUCKET"), - endpoint: env::var("AWS_ENDPOINT").ok(), - token: env::var("AWS_SESSION_TOKEN").ok(), - } + let config = AmazonS3Builder::new() + .with_access_key_id( + env::var("AWS_ACCESS_KEY_ID") + .expect("already checked AWS_ACCESS_KEY_ID"), + ) + .with_secret_access_key( + env::var("AWS_SECRET_ACCESS_KEY") + .expect("already checked AWS_SECRET_ACCESS_KEY"), + ) + .with_region( + env::var("AWS_DEFAULT_REGION") + .expect("already checked AWS_DEFAULT_REGION"), + ) + .with_bucket_name( + env::var("OBJECT_STORE_BUCKET") + .expect("already checked OBJECT_STORE_BUCKET"), + ) + .with_allow_http(true); + + let config = if let Some(endpoint) = env::var("AWS_ENDPOINT").ok() { + config.with_endpoint(endpoint) + } else { + config + }; + + let config = if let Some(token) = env::var("AWS_SESSION_TOKEN").ok() { + config.with_token(token) + } else { + config + }; + + config } }}; } @@ -1148,24 +1264,10 @@ mod tests { r } - fn make_integration(config: AwsConfig) -> AmazonS3 { - new_s3( - Some(config.access_key_id), - Some(config.secret_access_key), - config.region, - config.bucket, - config.endpoint, - config.token, - NonZeroUsize::new(16).unwrap(), - true, - ) - .expect("Valid S3 config") - } - #[tokio::test] async fn s3_test() { let config = maybe_skip_integration!(); - let integration = make_integration(config); + let integration = config.build().unwrap(); check_credentials(put_get_delete_list(&integration).await).unwrap(); check_credentials(list_uses_directories_correctly(&integration).await).unwrap(); @@ -1177,7 +1279,7 @@ mod tests { #[tokio::test] async fn s3_test_get_nonexistent_location() { let config = maybe_skip_integration!(); - let integration = make_integration(config); + let integration = config.build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -1204,9 +1306,8 @@ mod tests { #[tokio::test] async fn s3_test_get_nonexistent_bucket() { - let mut config = maybe_skip_integration!(); - config.bucket = NON_EXISTENT_NAME.into(); - let integration = make_integration(config); + let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); + let integration = config.build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -1220,9 +1321,9 @@ mod tests { #[tokio::test] async fn s3_test_put_nonexistent_bucket() { - let mut config = maybe_skip_integration!(); - config.bucket = NON_EXISTENT_NAME.into(); - let integration = make_integration(config); + let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); + + let integration = config.build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); let data = Bytes::from("arbitrary data"); @@ -1244,7 +1345,7 @@ mod tests { #[tokio::test] async fn s3_test_delete_nonexistent_location() { let config = maybe_skip_integration!(); - let integration = make_integration(config); + let integration = config.build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -1253,9 +1354,8 @@ mod tests { #[tokio::test] async fn s3_test_delete_nonexistent_bucket() { - let mut config = 
maybe_skip_integration!(); - config.bucket = NON_EXISTENT_NAME.into(); - let integration = make_integration(config); + let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME); + let integration = config.build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); diff --git a/object_store/src/azure.rs b/object_store/src/azure.rs index 25f311a9a39..dca52a356c1 100644 --- a/object_store/src/azure.rs +++ b/object_store/src/azure.rs @@ -185,6 +185,15 @@ enum Error { env_value: String, source: url::ParseError, }, + + #[snafu(display("Account must be specified"))] + MissingAccount {}, + + #[snafu(display("Access key must be specified"))] + MissingAccessKey {}, + + #[snafu(display("Container name must be specified"))] + MissingContainerName {}, } impl From for super::Error { @@ -570,73 +579,125 @@ fn url_from_env(env_name: &str, default_url: &str) -> Result { Ok(url) } -/// Configure a connection to container with given name on Microsoft Azure -/// Blob store. +/// Configure a connection to Mirosoft Azure Blob Storage bucket using +/// the specified credentials. /// -/// The credentials `account` and `access_key` must provide access to the -/// store. -pub fn new_azure( - account: impl Into, - access_key: impl Into, - container_name: impl Into, +/// # Example +/// ``` +/// # let ACCOUNT = "foo"; +/// # let BUCKET_NAME = "foo"; +/// # let ACCESS_KEY = "foo"; +/// let azure = object_store::azure::MicrosoftAzureBuilder::new() +/// .with_account(ACCOUNT) +/// .with_access_key(ACCESS_KEY) +/// .with_container_name(BUCKET_NAME) +/// .build(); +/// ``` +#[derive(Debug, Default)] +pub struct MicrosoftAzureBuilder { + account: Option, + access_key: Option, + container_name: Option, use_emulator: bool, -) -> Result { - let account = account.into(); - let access_key = access_key.into(); - let http_client: Arc = Arc::new(reqwest::Client::new()); - - let (is_emulator, storage_account_client) = if use_emulator { - check_if_emulator_works()?; - // Allow overriding defaults. Values taken from - // from https://docs.rs/azure_storage/0.2.0/src/azure_storage/core/clients/storage_account_client.rs.html#129-141 - let http_client = azure_core::new_http_client(); - let blob_storage_url = - url_from_env("AZURITE_BLOB_STORAGE_URL", "http://127.0.0.1:10000")?; - let queue_storage_url = - url_from_env("AZURITE_QUEUE_STORAGE_URL", "http://127.0.0.1:10001")?; - let table_storage_url = - url_from_env("AZURITE_TABLE_STORAGE_URL", "http://127.0.0.1:10002")?; - let filesystem_url = - url_from_env("AZURITE_TABLE_STORAGE_URL", "http://127.0.0.1:10004")?; - - let storage_client = StorageAccountClient::new_emulator( - http_client, - &blob_storage_url, - &table_storage_url, - &queue_storage_url, - &filesystem_url, - ); - - (true, storage_client) - } else { - ( - false, - StorageAccountClient::new_access_key( - Arc::clone(&http_client), - &account, - &access_key, - ), - ) - }; +} - let storage_client = storage_account_client.as_storage_client(); - let blob_base_url = storage_account_client - .blob_storage_url() - .as_ref() - // make url ending consistent between the emulator and remote storage account - .trim_end_matches('/') - .to_string(); +impl MicrosoftAzureBuilder { + /// Create a new [`MicrosoftAzureBuilder`] with default values. 
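+    ///
+    /// Values are then supplied via the `with_*` methods. As an illustrative
+    /// sketch, a configuration for the Azurite emulator (the account, key and
+    /// container below are placeholders) might look like:
+    ///
+    /// ```no_run
+    /// let azure = object_store::azure::MicrosoftAzureBuilder::new()
+    ///     .with_account("devstoreaccount1")
+    ///     .with_access_key("access-key")
+    ///     .with_container_name("test-container")
+    ///     .with_use_emulator(true)
+    ///     .build();
+    /// ```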
+ pub fn new() -> Self { + Default::default() + } - let container_name = container_name.into(); + /// Set the Azure Account (required) + pub fn with_account(mut self, account: impl Into) -> Self { + self.account = Some(account.into()); + self + } - let container_client = storage_client.as_container_client(&container_name); + /// Set the Azure Access Key (required) + pub fn with_access_key(mut self, access_key: impl Into) -> Self { + self.access_key = Some(access_key.into()); + self + } + + /// Set the Azure Container Name (required) + pub fn with_container_name(mut self, container_name: impl Into) -> Self { + self.container_name = Some(container_name.into()); + self + } - Ok(MicrosoftAzure { - container_client, - container_name, - blob_base_url, - is_emulator, - }) + /// Set if the Azure emulator should be used (defaults to false) + pub fn with_use_emulator(mut self, use_emulator: bool) -> Self { + self.use_emulator = use_emulator; + self + } + + /// Configure a connection to container with given name on Microsoft Azure + /// Blob store. + pub fn build(self) -> Result { + let Self { + account, + access_key, + container_name, + use_emulator, + } = self; + + let account = account.ok_or(Error::MissingAccount {})?; + let access_key = access_key.ok_or(Error::MissingAccessKey {})?; + let container_name = container_name.ok_or(Error::MissingContainerName {})?; + + let http_client: Arc = Arc::new(reqwest::Client::new()); + + let (is_emulator, storage_account_client) = if use_emulator { + check_if_emulator_works()?; + // Allow overriding defaults. Values taken from + // from https://docs.rs/azure_storage/0.2.0/src/azure_storage/core/clients/storage_account_client.rs.html#129-141 + let http_client = azure_core::new_http_client(); + let blob_storage_url = + url_from_env("AZURITE_BLOB_STORAGE_URL", "http://127.0.0.1:10000")?; + let queue_storage_url = + url_from_env("AZURITE_QUEUE_STORAGE_URL", "http://127.0.0.1:10001")?; + let table_storage_url = + url_from_env("AZURITE_TABLE_STORAGE_URL", "http://127.0.0.1:10002")?; + let filesystem_url = + url_from_env("AZURITE_TABLE_STORAGE_URL", "http://127.0.0.1:10004")?; + + let storage_client = StorageAccountClient::new_emulator( + http_client, + &blob_storage_url, + &table_storage_url, + &queue_storage_url, + &filesystem_url, + ); + + (true, storage_client) + } else { + ( + false, + StorageAccountClient::new_access_key( + Arc::clone(&http_client), + &account, + &access_key, + ), + ) + }; + + let storage_client = storage_account_client.as_storage_client(); + let blob_base_url = storage_account_client + .blob_storage_url() + .as_ref() + // make url ending consistent between the emulator and remote storage account + .trim_end_matches('/') + .to_string(); + + let container_client = storage_client.as_container_client(&container_name); + + Ok(MicrosoftAzure { + container_client, + container_name, + blob_base_url, + is_emulator, + }) + } } // Relevant docs: https://azure.github.io/Storage/docs/application-and-user-data/basics/azure-blob-storage-upload-apis/ @@ -729,21 +790,13 @@ impl CloudMultiPartUploadImpl for AzureMultiPartUpload { #[cfg(test)] mod tests { - use crate::azure::new_azure; + use super::*; use crate::tests::{ copy_if_not_exists, list_uses_directories_correctly, list_with_delimiter, put_get_delete_list, rename_and_copy, }; use std::env; - #[derive(Debug)] - struct AzureConfig { - storage_account: String, - access_key: String, - bucket: String, - use_emulator: bool, - } - // Helper macro to skip tests if TEST_INTEGRATION and the Azure environment // 
variables are not set. macro_rules! maybe_skip_integration { @@ -785,28 +838,23 @@ mod tests { ); return; } else { - AzureConfig { - storage_account: env::var("AZURE_STORAGE_ACCOUNT") - .unwrap_or_default(), - access_key: env::var("AZURE_STORAGE_ACCESS_KEY").unwrap_or_default(), - bucket: env::var("OBJECT_STORE_BUCKET") - .expect("already checked OBJECT_STORE_BUCKET"), - use_emulator, - } + MicrosoftAzureBuilder::new() + .with_account(env::var("AZURE_STORAGE_ACCOUNT").unwrap_or_default()) + .with_access_key( + env::var("AZURE_STORAGE_ACCESS_KEY").unwrap_or_default(), + ) + .with_container_name( + env::var("OBJECT_STORE_BUCKET") + .expect("already checked OBJECT_STORE_BUCKET"), + ) + .with_use_emulator(use_emulator) } }}; } #[tokio::test] async fn azure_blob_test() { - let config = maybe_skip_integration!(); - let integration = new_azure( - config.storage_account, - config.access_key, - config.bucket, - config.use_emulator, - ) - .unwrap(); + let integration = maybe_skip_integration!().build().unwrap(); put_get_delete_list(&integration).await.unwrap(); list_uses_directories_correctly(&integration).await.unwrap(); diff --git a/object_store/src/gcp.rs b/object_store/src/gcp.rs index d740625bd92..dea8769a736 100644 --- a/object_store/src/gcp.rs +++ b/object_store/src/gcp.rs @@ -98,6 +98,12 @@ enum Error { #[snafu(display("Error decoding object size: {}", source))] InvalidSize { source: std::num::ParseIntError }, + + #[snafu(display("Missing bucket name"))] + MissingBucketName {}, + + #[snafu(display("Missing service account path"))] + MissingServiceAccountPath, } impl From for super::Error { @@ -779,55 +785,121 @@ fn reader_credentials_file( Ok(serde_json::from_reader(reader).context(DecodeCredentialsSnafu)?) } -/// Configure a connection to Google Cloud Storage. -pub fn new_gcs( - service_account_path: impl AsRef, - bucket_name: impl Into, -) -> Result { - new_gcs_with_client(service_account_path, bucket_name, Client::new()) +/// Configure a connection to Google Cloud Storage using the specified +/// credentials. +/// +/// # Example +/// ``` +/// # let BUCKET_NAME = "foo"; +/// # let SERVICE_ACCOUNT_PATH = "/tmp/foo.json"; +/// let gcs = object_store::gcp::GoogleCloudStorageBuilder::new() +/// .with_service_account_path(SERVICE_ACCOUNT_PATH) +/// .with_bucket_name(BUCKET_NAME) +/// .build(); +/// ``` +#[derive(Debug, Default)] +pub struct GoogleCloudStorageBuilder { + bucket_name: Option, + service_account_path: Option, + client: Option, } -/// Configure a connection to Google Cloud Storage with the specified HTTP client. -pub fn new_gcs_with_client( - service_account_path: impl AsRef, - bucket_name: impl Into, - client: Client, -) -> Result { - let credentials = reader_credentials_file(service_account_path)?; - - // TODO: https://cloud.google.com/storage/docs/authentication#oauth-scopes - let scope = "https://www.googleapis.com/auth/devstorage.full_control"; - let audience = "https://www.googleapis.com/oauth2/v4/token".to_string(); - - let oauth_provider = (!credentials.disable_oauth) - .then(|| { - OAuthProvider::new( - credentials.client_email, - credentials.private_key, - scope.to_string(), - audience, - ) - }) - .transpose()?; +impl GoogleCloudStorageBuilder { + /// Create a new [`GoogleCloudStorageBuilder`] with default values. 
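+    ///
+    /// The bucket name and service account path are then set via the
+    /// `with_*` methods; `build` reads and validates the service account
+    /// file. An illustrative sketch (the path and bucket below are
+    /// placeholders):
+    ///
+    /// ```no_run
+    /// let gcs = object_store::gcp::GoogleCloudStorageBuilder::new()
+    ///     .with_service_account_path("/tmp/gcs.json")
+    ///     .with_bucket_name("test-bucket")
+    ///     .build()
+    ///     .expect("valid GCS configuration");
+    /// ```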
+ pub fn new() -> Self { + Default::default() + } - let bucket_name = bucket_name.into(); - let encoded_bucket_name = - percent_encode(bucket_name.as_bytes(), NON_ALPHANUMERIC).to_string(); + /// Set the bucket name (required) + pub fn with_bucket_name(mut self, bucket_name: impl Into) -> Self { + self.bucket_name = Some(bucket_name.into()); + self + } - // The cloud storage crate currently only supports authentication via - // environment variables. Set the environment variable explicitly so - // that we can optionally accept command line arguments instead. - Ok(GoogleCloudStorage { - client: Arc::new(GoogleCloudStorageClient { - client, - base_url: credentials.gcs_base_url, - oauth_provider, - token_cache: Default::default(), + /// Set the path to the service account file (required). Example + /// `"/tmp/gcs.json"` + /// + /// Example contents of `gcs.json`: + /// + /// ```json + /// { + /// "gcs_base_url": "https://localhost:4443", + /// "disable_oauth": true, + /// "client_email": "", + /// "private_key": "" + /// } + /// ``` + pub fn with_service_account_path( + mut self, + service_account_path: impl Into, + ) -> Self { + self.service_account_path = Some(service_account_path.into()); + self + } + + /// Use the specified http [`Client`] (defaults to [`Client::new`]) + /// + /// This allows you to set custom client options such as allowing + /// non secure connections or custom headers. + /// + /// NOTE: Currently only available in `test`s to facilitate + /// testing, to avoid leaking details and preserve our ability to + /// make changes to the implementation. + #[cfg(test)] + pub fn with_client(mut self, client: Client) -> Self { + self.client = Some(client); + self + } + + /// Configure a connection to Google Cloud Storage, returning a + /// new [`GoogleCloudStorage`] and consuming `self` + pub fn build(self) -> Result { + let Self { bucket_name, - bucket_name_encoded: encoded_bucket_name, - max_list_results: None, - }), - }) + service_account_path, + client, + } = self; + + let bucket_name = bucket_name.ok_or(Error::MissingBucketName {})?; + let service_account_path = + service_account_path.ok_or(Error::MissingServiceAccountPath)?; + let client = client.unwrap_or_else(Client::new); + + let credentials = reader_credentials_file(service_account_path)?; + + // TODO: https://cloud.google.com/storage/docs/authentication#oauth-scopes + let scope = "https://www.googleapis.com/auth/devstorage.full_control"; + let audience = "https://www.googleapis.com/oauth2/v4/token".to_string(); + + let oauth_provider = (!credentials.disable_oauth) + .then(|| { + OAuthProvider::new( + credentials.client_email, + credentials.private_key, + scope.to_string(), + audience, + ) + }) + .transpose()?; + + let encoded_bucket_name = + percent_encode(bucket_name.as_bytes(), NON_ALPHANUMERIC).to_string(); + + // The cloud storage crate currently only supports authentication via + // environment variables. Set the environment variable explicitly so + // that we can optionally accept command line arguments instead. 
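+        // Assemble the storage client: the bucket name is percent-encoded for
+        // use in request URLs, and `oauth_provider` is `None` when the service
+        // account file sets `disable_oauth` (e.g. when testing against a fake
+        // GCS server).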
+ Ok(GoogleCloudStorage { + client: Arc::new(GoogleCloudStorageClient { + client, + base_url: credentials.gcs_base_url, + oauth_provider, + token_cache: Default::default(), + bucket_name, + bucket_name_encoded: encoded_bucket_name, + max_list_results: None, + }), + }) + } } fn convert_object_meta(object: &Object) -> Result { @@ -860,24 +932,6 @@ mod test { const NON_EXISTENT_NAME: &str = "nonexistentname"; - #[derive(Debug)] - struct GoogleCloudConfig { - bucket: String, - service_account: String, - } - - impl GoogleCloudConfig { - fn build_test(self) -> Result { - // ignore HTTPS errors in tests so we can use fake-gcs server - let client = Client::builder() - .danger_accept_invalid_certs(true) - .build() - .expect("Error creating http client for testing"); - - new_gcs_with_client(self.service_account, self.bucket, client) - } - } - // Helper macro to skip tests if TEST_INTEGRATION and the GCP environment variables are not set. macro_rules! maybe_skip_integration { () => {{ @@ -912,20 +966,29 @@ mod test { ); return; } else { - GoogleCloudConfig { - bucket: env::var("OBJECT_STORE_BUCKET") - .expect("already checked OBJECT_STORE_BUCKET"), - service_account: env::var("GOOGLE_SERVICE_ACCOUNT") - .expect("already checked GOOGLE_SERVICE_ACCOUNT"), - } + GoogleCloudStorageBuilder::new() + .with_bucket_name( + env::var("OBJECT_STORE_BUCKET") + .expect("already checked OBJECT_STORE_BUCKET") + ) + .with_service_account_path( + env::var("GOOGLE_SERVICE_ACCOUNT") + .expect("already checked GOOGLE_SERVICE_ACCOUNT") + ) + .with_client( + // ignore HTTPS errors in tests so we can use fake-gcs server + Client::builder() + .danger_accept_invalid_certs(true) + .build() + .expect("Error creating http client for testing") + ) } }}; } #[tokio::test] async fn gcs_test() { - let config = maybe_skip_integration!(); - let integration = config.build_test().unwrap(); + let integration = maybe_skip_integration!().build().unwrap(); put_get_delete_list(&integration).await.unwrap(); list_uses_directories_correctly(&integration).await.unwrap(); @@ -940,8 +1003,7 @@ mod test { #[tokio::test] async fn gcs_test_get_nonexistent_location() { - let config = maybe_skip_integration!(); - let integration = config.build_test().unwrap(); + let integration = maybe_skip_integration!().build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -956,9 +1018,10 @@ mod test { #[tokio::test] async fn gcs_test_get_nonexistent_bucket() { - let mut config = maybe_skip_integration!(); - config.bucket = NON_EXISTENT_NAME.into(); - let integration = config.build_test().unwrap(); + let integration = maybe_skip_integration!() + .with_bucket_name(NON_EXISTENT_NAME) + .build() + .unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -975,8 +1038,7 @@ mod test { #[tokio::test] async fn gcs_test_delete_nonexistent_location() { - let config = maybe_skip_integration!(); - let integration = config.build_test().unwrap(); + let integration = maybe_skip_integration!().build().unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -990,9 +1052,10 @@ mod test { #[tokio::test] async fn gcs_test_delete_nonexistent_bucket() { - let mut config = maybe_skip_integration!(); - config.bucket = NON_EXISTENT_NAME.into(); - let integration = config.build_test().unwrap(); + let integration = maybe_skip_integration!() + .with_bucket_name(NON_EXISTENT_NAME) + .build() + .unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); @@ -1006,9 +1069,10 @@ mod test { #[tokio::test] async fn gcs_test_put_nonexistent_bucket() { - let 
mut config = maybe_skip_integration!(); - config.bucket = NON_EXISTENT_NAME.into(); - let integration = config.build_test().unwrap(); + let integration = maybe_skip_integration!() + .with_bucket_name(NON_EXISTENT_NAME) + .build() + .unwrap(); let location = Path::from_iter([NON_EXISTENT_NAME]); let data = Bytes::from("arbitrary data");