From f57eb76c36b3f6448b41cb71a7a5422f969bdd21 Mon Sep 17 00:00:00 2001 From: Jin Igarashi Date: Wed, 18 Dec 2024 20:13:18 +0900 Subject: [PATCH] feat: use prompt to setup config.json and root data folder in init command --- cbsurge/cli.py | 2 +- cbsurge/initialize.py | 85 ++++++++++++ cbsurge/session.py | 236 +++++++++++++++++++++++++++++----- pyproject.toml | 2 +- tests/cbsurge/test_session.py | 19 +++ 5 files changed, 307 insertions(+), 37 deletions(-) create mode 100644 cbsurge/initialize.py create mode 100644 tests/cbsurge/test_session.py diff --git a/cbsurge/cli.py b/cbsurge/cli.py index 227edb95..27e3c1f8 100644 --- a/cbsurge/cli.py +++ b/cbsurge/cli.py @@ -1,5 +1,5 @@ import click as click -from cbsurge.session import init +from cbsurge.initialize import init from cbsurge.admin import admin from cbsurge.exposure.builtenv import builtenv from cbsurge.exposure.population import population diff --git a/cbsurge/initialize.py b/cbsurge/initialize.py new file mode 100644 index 00000000..8b931bef --- /dev/null +++ b/cbsurge/initialize.py @@ -0,0 +1,85 @@ +import logging +import click +import os +import shutil +from cbsurge.session import Session +from cbsurge.admin import silence_httpx_az + +logger = logging.getLogger(__name__) + + +def setup_prompt(session: Session): + credential, token = session.authenticate() + logger.debug(token) + if token is None: + click.prompt("Authentication failed.Please `az login` to authenticate first.") + return + + click.echo("Authentication successful. We need more information to setup from you.") + + # project root data folder + root_data_folder = None + absolute_root_data_folder = None + while not(root_data_folder is not None and os.path.exists(absolute_root_data_folder)): + data_folder = click.prompt("Please enter project root folder to store all data. Enter to skip if use default value", default="~/cbsurge") + absolute_root_data_folder = os.path.expanduser(data_folder) + + if os.path.exists(absolute_root_data_folder): + if click.confirm("The folder already exists. Yes to overwrite, No/Enter to use existing folder", default=False): + shutil.rmtree(absolute_root_data_folder) + click.echo(f"Removed folder {absolute_root_data_folder}") + os.makedirs(absolute_root_data_folder) + click.echo(f"The project root folder was created at {absolute_root_data_folder}") + root_data_folder = data_folder + else: + click.echo(f"Use {absolute_root_data_folder} as the root folder.") + root_data_folder = data_folder + else: + os.makedirs(absolute_root_data_folder) + click.echo(f"The project root folder was created at {absolute_root_data_folder}") + root_data_folder = data_folder + session.set_root_data_folder(root_data_folder) + + # azure blob container setting + account_name = click.prompt('Please enter account name for UNDP Azure. Enter to skip if use default value', + type=str, default='undpgeohub') + session.set_account_name(account_name) + click.echo(f"account name: {account_name}") + + container_name = click.prompt('Please enter container name for UNDP Azure. Enter to skip if use default value', + type=str, default='stacdata') + session.set_container_name(container_name) + click.echo(f"container name: {container_name}") + + # azure file share setting + share_name = click.prompt('Please enter share name for UNDP Azure. Enter to skip if use default value', + type=str, default='cbrapida') + session.set_file_share_name(share_name) + click.echo(f"file share name: {share_name}") + + session.save_config() + click.echo('Setting up was successfully done!') + + +@click.command() +@click.option('--debug', + is_flag=True, + default=False, + help="Set log level to debug" + ) +def init(debug=False): + """ + This command setup rapida command environment by authenticating to Azure. + """ + silence_httpx_az() + logging.basicConfig(level=logging.DEBUG if debug else logging.INFO, force=True) + + click.echo("Welcome to rapida CLI tool!") + with Session() as session: + config = session.get_config() + if config: + if click.confirm('Your setup has already been done. Would you like to do setup again?', abort=True): + setup_prompt(session) + else: + if click.confirm('Would you like to setup rapida tool?', abort=True): + setup_prompt(session) diff --git a/cbsurge/session.py b/cbsurge/session.py index bc7a7da1..31c6caa1 100644 --- a/cbsurge/session.py +++ b/cbsurge/session.py @@ -1,28 +1,134 @@ import logging -import click +import os +import json from azure.identity import DefaultAzureCredential, AzureAuthorityHosts from azure.core.exceptions import ClientAuthenticationError -from azure.storage.blob import BlobServiceClient +from azure.storage.blob.aio import BlobServiceClient, ContainerClient +from azure.storage.fileshare.aio import ShareServiceClient logger = logging.getLogger(__name__) class Session(object): - def __init__(self, scopes = "https://storage.azure.com/.default"): + def __init__(self): """ constructor - - Parameters: - scopes: scopes for get_token method. Default to "https://storage.azure.com/.default" """ - self.scopes = scopes + self.config = self.get_config() + if self.config is not None: + logger.debug(f"config was loaded: {self.config}") + def __enter__(self): return self + def __exit__(self, exc_type, exc_value, traceback): - self.scopes = None + pass + + + def get_config_file_path(self) -> str: + user_dir = os.path.expanduser("~") + config_file_path = os.path.join(user_dir, ".cbsurge", "config.json") + return config_file_path + + + def get_config(self): + """ + get config from ~/.cbsurge/config.json + + Returns: + JSON object + """ + config_file_path = self.get_config_file_path() + if os.path.exists(config_file_path): + with open(config_file_path, "r", encoding="utf-8") as data: + return json.load(data) + else: + return None + + def get_config_value_by_key(self, key: str, default=None): + """ + get config value by key + + Parameters: + key (str): key + default (str): default value if not exists. Default is None + """ + if self.config is None: + self.config = self.get_config() + if self.config is not None: + return self.config.get(key, default) + else: + return default + + + def set_config_value_by_key(self, key: str, value): + if self.config is None: + self.config = {} + self.config[key] = value + + + def set_root_data_folder(self, folder_name): + self.set_config_value_by_key("root_data_folder", folder_name) + + def get_root_data_folder(self, is_absolute_path=True): + """ + get root data folder + + Parameters: + is_absolute_path (bool): Optional. If true, return absolute path, otherwise relative path. Default is True. + Returns: + root data folder path (str) + """ + root_data_folder = self.get_config_value_by_key("root_data_folder") + if is_absolute_path: + return os.path.expanduser(root_data_folder) + else: + return root_data_folder + + def set_account_name(self, account_name: str): + self.set_config_value_by_key("account_name", account_name) + + def get_account_name(self): + return self.get_config_value_by_key("account_name") + + def set_container_name(self, container_name: str): + self.set_config_value_by_key("container_name", container_name) + + def get_container_name(self): + return self.get_config_value_by_key("container_name") + + def set_file_share_name(self, file_share_name: str): + self.set_config_value_by_key("file_share_name", file_share_name) + + def get_file_share_name(self): + return self.get_config_value_by_key("file_share_name") + + def save_config(self): + """ + Save config.json under user directory as ~/.cbsurge/config.json + """ + if self.get_root_data_folder() is None: + raise RuntimeError(f"root_data_folder is not set") + if self.get_account_name() is None: + raise RuntimeError(f"account_name is not set") + if self.get_container_name() is None: + raise RuntimeError(f"container_name is not set") + if self.get_file_share_name() is None: + raise RuntimeError(f"file_share_name is not set") + + config_file_path = self.get_config_file_path() + + dir_path = os.path.dirname(config_file_path) + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + with open(config_file_path, "w", encoding="utf-8") as file: + json.dump(self.config, file, ensure_ascii=False, indent=4) + + logger.debug(f"config file was saved to {config_file_path}") def get_credential(self): @@ -48,10 +154,13 @@ def get_credential(self): credential = DefaultAzureCredential() return credential - def get_token(self): + + def get_token(self, scopes = "https://storage.azure.com/.default"): """ get access token for blob storage account. This token is required for using Azure REST API. + Parameters: + scopes: scopes for get_token method. Default to "https://storage.azure.com/.default" Returns: Azure token is returned if authenticated. Raises: @@ -62,13 +171,14 @@ def get_token(self): """ try: credential = self.get_credential() - token = credential.get_token(self.scopes) + token = credential.get_token(scopes) return token except ClientAuthenticationError as err: logger.error("authentication failed. Please use 'rapida init' command to setup credentials.") raise err - def authenticate(self): + + def authenticate(self, scopes = "https://storage.azure.com/.default"): """ Authenticate to Azure through interactive browser if DefaultAzureCredential is not provideds. Authentication uses DefaultAzureCredential. @@ -76,6 +186,8 @@ def authenticate(self): Please refer to https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python about DefaultAzureCredential api specificaiton. + Parameters: + scopes: scopes for get_token method. Default to "https://storage.azure.com/.default" Returns: Azure credential and token are returned if authenticated. If authentication failed, return None. """ @@ -83,17 +195,20 @@ def authenticate(self): credential = DefaultAzureCredential( exclude_interactive_browser_credential=False, ) - token = credential.get_token(self.scopes) + token = credential.get_token(scopes) return [credential, token] except ClientAuthenticationError as err: logger.error("authentication failed.") logger.error(err) return None - def get_blob_service_client(self, account_name: str) -> BlobServiceClient: + + def get_blob_service_client(self, account_name: str = None) -> BlobServiceClient: """ get BlobServiceClient for account url + If the parameter is not set, use default account name from config. + Usage example: with Session() as session: blob_service_client = session.get_blob_service_client( @@ -106,32 +221,83 @@ def get_blob_service_client(self, account_name: str) -> BlobServiceClient: BlobServiceClient """ credential = self.get_credential() + account_url = self.get_blob_service_account_url(account_name) blob_service_client = BlobServiceClient( - account_url=f"https://{account_name}.blob.core.windows.net", + account_url=account_url, credential=credential ) return blob_service_client + def get_blob_container_client(self, account_name: str = None, container_name: str = None) -> ContainerClient: + """ + get ContainerClient for account name and container name -@click.command() -@click.option('--debug', - is_flag=True, - default=False, - help="Set log level to debug" - ) -def init(debug=False): - """ - This command setup rapida command environment by authenticating to Azure. - """ - logging.basicConfig(level=logging.DEBUG if debug else logging.INFO, force=True) - - if click.confirm('Would you like to setup rapida tool?', abort=True): - # login to Azure - session = Session() - credential, token = session.authenticate() - logger.debug(token) - if token is None: - logger.info("Authentication failed.Please `az login` to authenticate first.") - return - click.echo('Setting up was successfully done!') + If the parameter is not set, use default account name from config. + + Parameters: + account_name (str): name of storage account. https://{account_name}.blob.core.windows.net + container_name (str): name of storage container name. https://{account_name}.blob.core.windows.net/{container_name} + Returns: + ContainerClient + """ + blob_service_client = self.get_blob_service_client(account_name) + ct_name = container_name if container_name is not None else self.get_container_name() + container_client = blob_service_client.get_container_client(ct_name) + return container_client + + def get_blob_service_account_url(self, account_name: str = None) -> str: + """ + get blob service account URL + + If the parameter is not set, use default account name from config. + + Parameters: + account_name (str): Optional. name of storage account url. + """ + ac_name = account_name if account_name is not None else self.get_account_name() + return f"https://{ac_name}.blob.core.windows.net" + def get_share_service_client(self, account_name: str = None, share_name: str = None) -> ShareServiceClient: + """ + get ShareServiceClient for account url + + If the parameter is not set, use default account name from config. + + Usage example: + with Session() as session: + share_service_client = session.get_share_service_client( + account_name="undpgeohub", + share_name="cbrapida" + ) + + Parameters: + account_name (str): name of storage account. + share_name (str): name of file share. + + both parameters are equivalent to the below URL's bracket places. + + https://{account_name}.file.core.windows.net/{share_name} + Returns: + ShareServiceClient + """ + credential = self.get_credential() + account_url = self.get_share_service_account_url(account_name, share_name) + share_service_client = ShareServiceClient( + account_url=account_url, + credential=credential + ) + return share_service_client + + def get_file_share_account_url(self, account_name: str = None, share_name: str = None) -> str: + """ + get blob service account URL + + If the parameter is not set, use default account name from config. + + Parameters: + account_name (str): Optional. name of storage account url. If the parameter is not set, use default account name from config. + share_name (str): name of file share. If the parameter is not set, use default account name from config. + """ + ac_name = account_name if account_name is not None else self.get_account_name() + sh_name = share_name if share_name is not None else self.get_file_share_name() + return f"https://{ac_name}.file.core.windows.net/{sh_name}" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 59280b01..35dd41a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "async", "aiofiles", "aiohttp", - "azure-storage-blob", + "azure-storage-file-share", "asyncclick", "rio-cogeo", "exactextract", diff --git a/tests/cbsurge/test_session.py b/tests/cbsurge/test_session.py new file mode 100644 index 00000000..f614d276 --- /dev/null +++ b/tests/cbsurge/test_session.py @@ -0,0 +1,19 @@ +from cbsurge.session import Session +import os + + +def test_session(): + with Session() as s: + s.set_account_name("test_account") + s.set_file_share_name("test_share") + s.set_root_data_folder("~/cbsurge") + + assert s.get_blob_service_account_url() == "https://test_account.blob.core.windows.net" + assert s.get_file_share_account_url() == "https://test_account.file.core.windows.net/test_share" + + assert s.get_blob_service_account_url(account_name="aaa") == "https://aaa.blob.core.windows.net" + assert s.get_file_share_account_url(account_name="aaa") == "https://aaa.file.core.windows.net/test_share" + assert s.get_file_share_account_url(account_name="aaa", share_name="bbb") == "https://aaa.file.core.windows.net/bbb" + + assert s.get_root_data_folder(False) == "~/cbsurge" + assert s.get_root_data_folder(True) == os.path.expanduser("~/cbsurge") \ No newline at end of file