From 9893b741fa39333693d06ceb8e592825a12216b4 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Wed, 27 Jan 2021 23:10:26 +0000 Subject: [PATCH 01/25] Initial Commit --- awswrangler/__init__.py | 2 + awswrangler/_config.py | 11 ++ awswrangler/_utils.py | 2 + awswrangler/lakeformation/__init__.py | 9 ++ awswrangler/lakeformation/_read.py | 139 ++++++++++++++++++++++++++ awswrangler/lakeformation/_utils.py | 49 +++++++++ 6 files changed, 212 insertions(+) create mode 100644 awswrangler/lakeformation/__init__.py create mode 100644 awswrangler/lakeformation/_read.py create mode 100644 awswrangler/lakeformation/_utils.py diff --git a/awswrangler/__init__.py b/awswrangler/__init__.py index 25785e433..6249471ff 100644 --- a/awswrangler/__init__.py +++ b/awswrangler/__init__.py @@ -15,6 +15,7 @@ dynamodb, emr, exceptions, + lakeformation, mysql, postgresql, quicksight, @@ -40,6 +41,7 @@ "s3", "sts", "redshift", + "lakeformation", "mysql", "postgresql", "secretsmanager", diff --git a/awswrangler/_config.py b/awswrangler/_config.py index 852ea762a..4e5e5a169 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -41,6 +41,7 @@ class _ConfigArg(NamedTuple): "redshift_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True), "kms_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True), "emr_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True), + "lakeformation_endpoint_url": _ConfigArg(dtype=str, nullable=True, enforced=True), } @@ -57,6 +58,7 @@ def __init__(self) -> None: self.redshift_endpoint_url = None self.kms_endpoint_url = None self.emr_endpoint_url = None + self.lakeformation_endpoint_url = None for name in _CONFIG_ARGS: self._load_config(name=name) @@ -338,6 +340,15 @@ def emr_endpoint_url(self) -> Optional[str]: def emr_endpoint_url(self, value: Optional[str]) -> None: self._set_config_value(key="emr_endpoint_url", value=value) + @property + def lakeformation_endpoint_url(self) -> Optional[str]: + """Property lakeformation_endpoint_url.""" + return cast(Optional[str], self["lakeformation_endpoint_url"]) + + @lakeformation_endpoint_url.setter + def lakeformation_endpoint_url(self, value: Optional[str]) -> None: + self._set_config_value(key="lakeformation_endpoint_url", value=value) + def _insert_str(text: str, token: str, insert: str) -> str: """Insert string into other.""" diff --git a/awswrangler/_utils.py b/awswrangler/_utils.py index 7e4141a1b..271b34b72 100644 --- a/awswrangler/_utils.py +++ b/awswrangler/_utils.py @@ -80,6 +80,8 @@ def _get_endpoint_url(service_name: str) -> Optional[str]: endpoint_url = _config.config.kms_endpoint_url elif service_name == "emr" and _config.config.emr_endpoint_url is not None: endpoint_url = _config.config.emr_endpoint_url + elif service_name == "lakeformation" and _config.config.lakeformation_endpoint_url is not None: + endpoint_url = _config.config.lakeformation_endpoint_url return endpoint_url diff --git a/awswrangler/lakeformation/__init__.py b/awswrangler/lakeformation/__init__.py new file mode 100644 index 000000000..daf66c386 --- /dev/null +++ b/awswrangler/lakeformation/__init__.py @@ -0,0 +1,9 @@ +"""Amazon Lake Formation Module.""" + +from awswrangler.lakeformation._read import read_sql_query # noqa +from awswrangler.lakeformation._utils import wait_query # noqa + +__all__ = [ + "read_sql_query", + "wait_query", +] diff --git a/awswrangler/lakeformation/_read.py b/awswrangler/lakeformation/_read.py new file mode 100644 index 000000000..53db9882a --- /dev/null +++ 
b/awswrangler/lakeformation/_read.py @@ -0,0 +1,139 @@ +"""Amazon Lake Formation Module gathering all read functions.""" +import logging +import sys +from typing import Any, Dict, Iterator, Optional, Union + +import boto3 +import pandas as pd +import pyarrow as pa + +from awswrangler import _utils, exceptions +from awswrangler._config import apply_configs +from awswrangler.lakeformation._utils import wait_query + +_logger: logging.Logger = logging.getLogger(__name__) + + +@apply_configs +def read_sql_query( + sql: str, + database: str, + transaction_id: Optional[str] = None, + query_as_of_time: Optional[str] = None, + catalog_id: Optional[str] = None, + chunksize: Optional[Union[int, bool]] = None, + boto3_session: Optional[boto3.Session] = None, + params: Optional[Dict[str, Any]] = None, +) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: + """Execute PartiQL query against an AWS Glue Governed Table based on Transaction ID or time travel timestamp. Return single Pandas DataFrame or Iterator. + + Note + ---- + The database must NOT be explicitely defined in the PartiQL statement. + i.e. sql="SELECT * FROM my_table" is valid + but sql="SELECT * FROM my_db.my_table" is NOT valid + + Note + ---- + Pass one of `transaction_id` or `query_as_of_time`, not both. + + Note + ---- + `chunksize` argument (memory-friendly) (i.e batching): + + Return an Iterable of DataFrames instead of a regular DataFrame. + + There are two batching strategies: + + - If **chunksize=True**, a new DataFrame will be returned for each file in the query result. + + - If **chunksize=INTEGER**, Wrangler will iterate on the data so that the DataFrames number of rows is equal to INTEGER. + + `P.S.` `chunksize=True` is faster and uses less memory + + Parameters + ---------- + sql : str + partiQL query. + database : str + AWS Glue database name + transaction_id : str, optional + The ID of the transaction at which to read the table contents. Cannot be specified alongside query_as_of_time + query_as_of_time : str, optional + The time as of when to read the table contents. Must be a valid Unix epoch timestamp. Cannot be specified alongside transaction_id + catalog_id : str, optional + The ID of the Data Catalog from which to retrieve Databases. + If none is provided, the AWS account ID is used by default. + chunksize : Union[int, bool], optional + If passed will split the data into an Iterable of DataFrames (memory-friendly). + If `True`, Wrangler will iterate on the data by files in the most efficient way without guarantee of chunksize. + If an `INTEGER` is passed, Wrangler will iterate on the data so that the DataFrames number of rows is equal to INTEGER. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receives None. + params: Dict[str, any], optional + Dict of parameters used to format the partiQL query. Only named parameters are supported. + The dict must contain the information in the form {'name': 'value'} and the SQL query must contain + `:name;`. + + Returns + ------- + Union[pd.DataFrame, Iterator[pd.DataFrame]] + Pandas DataFrame or Generator of Pandas DataFrames if chunksize is passed. + + Examples + -------- + >>> import awswrangler as wr + >>> df = wr.lakeformation.read_sql_query( + ... sql="SELECT * FROM my_table LIMIT 10;", + ... database="my_db", + ... transaction_id="ba9a11b5-619a-4ac3-bd70-5a744d09414c" + ... ) + + >>> import awswrangler as wr + >>> df = wr.lakeformation.read_sql_query( + ... sql="SELECT * FROM my_table WHERE name=:name;", + ... 
database="my_db", + ... query_as_of_time="1611142914", + ... params={"name": "filtered_name"} + ... ) + + """ + if transaction_id is None and query_as_of_time is None: + raise exceptions.InvalidArgumentCombination("Please pass one of transaction_id or query_as_of_time") + # TODO: Generate transaction_id if both transaction_id and query_as_of_time missing? + if transaction_id is not None and query_as_of_time is not None: + raise exceptions.InvalidArgumentCombination("Please pass only one of transaction_id or query_as_of_time, not both") + session: boto3.Session = _utils.ensure_session(session=boto3_session) + chunksize = sys.maxsize if chunksize is True else chunksize + if params is None: + params = {} + for key, value in params.items(): + sql = sql.replace(f":{key};", str(value)) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + # TODO: Check if the Glue Table is governed? + + args: Dict[str, Any] = { + "DatabaseName": database, + "Statement": sql + } + if catalog_id: + args["CatalogId"] = catalog_id + if transaction_id: + args["TransactionId"] = transaction_id + else: + args["QueryAsOfTime"] = query_as_of_time + query_id: str = client_lakeformation.plan_query(**args)["QueryId"] + + wait_query(query_id=query_id, boto3_session=session) + + work_units_output: Dict[str, Any] = client_lakeformation.get_work_units(QueryId=query_id) + print(work_units_output) + + a = client_lakeformation.execute(QueryId=query_id, Token=work_units_output["Units"][0]["Token"], WorkUnitId=0) + print(a) + + buf = a["Messages"].read() + stream = pa.RecordBatchStreamReader(buf) + table = stream.read_all() + df = table.to_pandas() + return df diff --git a/awswrangler/lakeformation/_utils.py b/awswrangler/lakeformation/_utils.py new file mode 100644 index 000000000..664b5ea65 --- /dev/null +++ b/awswrangler/lakeformation/_utils.py @@ -0,0 +1,49 @@ +"""Utilities Module for Amazon Lake Formation.""" +import logging +import time +from typing import Any, Dict, List, Optional + +import boto3 + +from awswrangler import _utils, exceptions + +_QUERY_FINAL_STATES: List[str] = ["ERROR", "FINISHED"] +_QUERY_WAIT_POLLING_DELAY: float = 2 # SECONDS + +_logger: logging.Logger = logging.getLogger(__name__) + + +def wait_query(query_id: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, Any]: + """Wait for the query to end. + + Parameters + ---------- + query_id : str + Lake Formation query execution ID. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session received None. + + Returns + ------- + Dict[str, Any] + Dictionary with the get_query_state response. 
+ + Examples + -------- + >>> import awswrangler as wr + >>> res = wr.lakeformation.wait_query(query_id='query-id') + + """ + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + + response: Dict[str, Any] = client_lakeformation.get_query_state(QueryId=query_id) + state: str = response["State"] + while state not in _QUERY_FINAL_STATES: + time.sleep(_QUERY_WAIT_POLLING_DELAY) + response = client_lakeformation.get_query_state(QueryId=query_id) + state = response["State"] + _logger.debug("state: %s", state) + if state == "ERROR": + raise exceptions.QueryFailed(response.get("Error")) + return response From ea9986ddc277a828287268876b700104814b4311 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Thu, 28 Jan 2021 21:57:11 +0000 Subject: [PATCH 02/25] Minor - Refactoring Work Units Logic --- awswrangler/lakeformation/_read.py | 71 ++++++++++++++++++------------ 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/awswrangler/lakeformation/_read.py b/awswrangler/lakeformation/_read.py index 53db9882a..48a6006e5 100644 --- a/awswrangler/lakeformation/_read.py +++ b/awswrangler/lakeformation/_read.py @@ -1,11 +1,11 @@ """Amazon Lake Formation Module gathering all read functions.""" import logging import sys -from typing import Any, Dict, Iterator, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import boto3 import pandas as pd -import pyarrow as pa +from pyarrow import RecordBatchStreamReader from awswrangler import _utils, exceptions from awswrangler._config import apply_configs @@ -25,7 +25,7 @@ def read_sql_query( boto3_session: Optional[boto3.Session] = None, params: Optional[Dict[str, Any]] = None, ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: - """Execute PartiQL query against an AWS Glue Governed Table based on Transaction ID or time travel timestamp. Return single Pandas DataFrame or Iterator. + """Execute PartiQL query against an AWS Glue Table based on Transaction ID or time travel timestamp. Return single Pandas DataFrame or Iterator. Note ---- @@ -82,11 +82,18 @@ def read_sql_query( Examples -------- + >>> import awswrangler as wr + >>> df = wr.lakeformation.read_sql_query( + ... sql="SELECT * FROM my_table;", + ... database="my_db", + ... catalog_id="111111111111" + ... ) + >>> import awswrangler as wr >>> df = wr.lakeformation.read_sql_query( ... sql="SELECT * FROM my_table LIMIT 10;", ... database="my_db", - ... transaction_id="ba9a11b5-619a-4ac3-bd70-5a744d09414c" + ... transaction_id="1b62811fa3e02c4e5fdbaa642b752030379c4a8a70da1f8732ce6ccca47afdc9" ... ) >>> import awswrangler as wr @@ -94,46 +101,54 @@ def read_sql_query( ... sql="SELECT * FROM my_table WHERE name=:name;", ... database="my_db", ... query_as_of_time="1611142914", - ... params={"name": "filtered_name"} + ... params={"name": "\'filtered_name\'"} ... ) """ - if transaction_id is None and query_as_of_time is None: - raise exceptions.InvalidArgumentCombination("Please pass one of transaction_id or query_as_of_time") - # TODO: Generate transaction_id if both transaction_id and query_as_of_time missing? 
if transaction_id is not None and query_as_of_time is not None: - raise exceptions.InvalidArgumentCombination("Please pass only one of transaction_id or query_as_of_time, not both") + raise exceptions.InvalidArgumentCombination( + "Please pass only one of `transaction_id` or `query_as_of_time`, not both" + ) session: boto3.Session = _utils.ensure_session(session=boto3_session) chunksize = sys.maxsize if chunksize is True else chunksize if params is None: params = {} for key, value in params.items(): - sql = sql.replace(f":{key};", str(value)) + sql = sql.replace(f":{key}", str(value)) client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) - # TODO: Check if the Glue Table is governed? - args: Dict[str, Any] = { - "DatabaseName": database, - "Statement": sql - } + args: Dict[str, Any] = {"DatabaseName": database, "Statement": sql} if catalog_id: args["CatalogId"] = catalog_id - if transaction_id: + if query_as_of_time: + args["QueryAsOfTime"] = query_as_of_time + elif transaction_id: args["TransactionId"] = transaction_id else: - args["QueryAsOfTime"] = query_as_of_time + _logger.debug("Neither `transaction_id` nor `query_as_of_time` were specified, beginning transaction") + transaction_id = client_lakeformation.begin_transaction(ReadOnly=True)["TransactionId"] + args["TransactionId"] = transaction_id query_id: str = client_lakeformation.plan_query(**args)["QueryId"] wait_query(query_id=query_id, boto3_session=session) - work_units_output: Dict[str, Any] = client_lakeformation.get_work_units(QueryId=query_id) - print(work_units_output) - - a = client_lakeformation.execute(QueryId=query_id, Token=work_units_output["Units"][0]["Token"], WorkUnitId=0) - print(a) - - buf = a["Messages"].read() - stream = pa.RecordBatchStreamReader(buf) - table = stream.read_all() - df = table.to_pandas() - return df + scan_kwargs: Dict[str, Union[str, int]] = {"QueryId": query_id, "PageSize": 2} # TODO: Inquire about good page size + next_token: str = "init_token" # Dummy token + token_work_units: List[Tuple[str, int]] = [] + while next_token: + response = client_lakeformation.get_work_units(**scan_kwargs) + token_work_units.extend( # [(Token0, WorkUnitId0), (Token0, WorkUnitId1), (Token1, WorkUnitId0) ... 
] + [ + (unit["Token"], unit_id) + for unit in response["Units"] + for unit_id in range(unit["WorkUnitIdMin"], unit["WorkUnitIdMax"] + 1) # Max is inclusive + ] + ) + next_token = response["NextToken"] + scan_kwargs["NextToken"] = next_token + + dfs: List[pd.DataFrame] = [] + for token, work_unit in token_work_units: + messages: Any = client_lakeformation.execute(QueryId=query_id, Token=token, WorkUnitId=work_unit)["Messages"] + dfs.append(RecordBatchStreamReader(messages.read()).read_pandas()) + return pd.concat(dfs) From f235e7decb58e218224c9f4a34d7ef7162c8f908 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Fri, 29 Jan 2021 16:38:28 +0000 Subject: [PATCH 03/25] Major - Checkpoint w/ functional read code/example --- awswrangler/lakeformation/_read.py | 141 ++++++++++++++++++----------- tests/test_lakeformation.py | 35 +++++++ 2 files changed, 125 insertions(+), 51 deletions(-) create mode 100644 tests/test_lakeformation.py diff --git a/awswrangler/lakeformation/_read.py b/awswrangler/lakeformation/_read.py index 48a6006e5..df9e9f4d7 100644 --- a/awswrangler/lakeformation/_read.py +++ b/awswrangler/lakeformation/_read.py @@ -1,11 +1,12 @@ """Amazon Lake Formation Module gathering all read functions.""" +import concurrent.futures +import itertools import logging -import sys from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import boto3 import pandas as pd -from pyarrow import RecordBatchStreamReader +from pyarrow import NativeFile, RecordBatchStreamReader from awswrangler import _utils, exceptions from awswrangler._config import apply_configs @@ -14,6 +15,71 @@ _logger: logging.Logger = logging.getLogger(__name__) +def _execute_query( + query_id: str, + token_work_unit: Tuple[str, int], + boto3_session: Optional[boto3.Session] = None, +) -> pd.DataFrame: + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + token, work_unit = token_work_unit + messages: NativeFile = client_lakeformation.execute(QueryId=query_id, Token=token, WorkUnitId=work_unit)["Messages"] + return RecordBatchStreamReader(messages.read()).read_pandas() + + +def _resolve_sql_query( + query_id: str, + chunked: Optional[bool] = None, + use_threads: bool = True, + boto3_session: Optional[boto3.Session] = None, +) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + + wait_query(query_id=query_id, boto3_session=session) + + # The LF Query Engine distributes the load across workers + # Retrieve the tokens and their associated work units until NextToken is '' + # One Token can span multiple work units + # PageSize determines the size of the "Units" array in each call + # TODO: Inquire about good page size # pylint: disable=W0511 + scan_kwargs: Dict[str, Union[str, int]] = {"QueryId": query_id, "PageSize": 2} + next_token: str = "init_token" # Dummy token + token_work_units: List[Tuple[str, int]] = [] + while next_token: + response = client_lakeformation.get_work_units(**scan_kwargs) + token_work_units.extend( # [(Token0, WorkUnitId0), (Token0, WorkUnitId1), (Token1, WorkUnitId2) ... 
] + [ + (unit["Token"], unit_id) + for unit in response["Units"] + for unit_id in range(unit["WorkUnitIdMin"], unit["WorkUnitIdMax"] + 1) # Max is inclusive + ] + ) + next_token = response["NextToken"] + scan_kwargs["NextToken"] = next_token + + dfs: List[pd.DataFrame] = list() + if use_threads is False: + dfs = list( + _execute_query(query_id=query_id, token_work_unit=token_work_unit, boto3_session=boto3_session) + for token_work_unit in token_work_units + ) + else: + cpus: int = _utils.ensure_cpu_count(use_threads=use_threads) + with concurrent.futures.ThreadPoolExecutor(max_workers=cpus) as executor: + dfs = list( + executor.map( + _execute_query, + itertools.repeat(query_id), + token_work_units, + itertools.repeat(_utils.boto3_to_primitives(boto3_session=boto3_session)), + ) + ) + if not chunked: + return pd.concat(dfs) + return dfs + + @apply_configs def read_sql_query( sql: str, @@ -21,11 +87,12 @@ def read_sql_query( transaction_id: Optional[str] = None, query_as_of_time: Optional[str] = None, catalog_id: Optional[str] = None, - chunksize: Optional[Union[int, bool]] = None, + chunked: bool = False, + use_threads: bool = True, boto3_session: Optional[boto3.Session] = None, params: Optional[Dict[str, Any]] = None, ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: - """Execute PartiQL query against an AWS Glue Table based on Transaction ID or time travel timestamp. Return single Pandas DataFrame or Iterator. + """Execute PartiQL query on AWS Glue Table (Transaction ID or time travel timestamp). Return Pandas DataFrame. Note ---- @@ -39,17 +106,8 @@ def read_sql_query( Note ---- - `chunksize` argument (memory-friendly) (i.e batching): - - Return an Iterable of DataFrames instead of a regular DataFrame. - - There are two batching strategies: - - - If **chunksize=True**, a new DataFrame will be returned for each file in the query result. - - - If **chunksize=INTEGER**, Wrangler will iterate on the data so that the DataFrames number of rows is equal to INTEGER. - - `P.S.` `chunksize=True` is faster and uses less memory + `chunked` argument (memory-friendly): + If set to `True`, return an Iterable of DataFrames instead of a regular DataFrame. Parameters ---------- @@ -58,27 +116,30 @@ def read_sql_query( database : str AWS Glue database name transaction_id : str, optional - The ID of the transaction at which to read the table contents. Cannot be specified alongside query_as_of_time + The ID of the transaction at which to read the table contents. + Cannot be specified alongside query_as_of_time query_as_of_time : str, optional - The time as of when to read the table contents. Must be a valid Unix epoch timestamp. Cannot be specified alongside transaction_id + The time as of when to read the table contents. Must be a valid Unix epoch timestamp. + Cannot be specified alongside transaction_id catalog_id : str, optional The ID of the Data Catalog from which to retrieve Databases. If none is provided, the AWS account ID is used by default. - chunksize : Union[int, bool], optional - If passed will split the data into an Iterable of DataFrames (memory-friendly). - If `True`, Wrangler will iterate on the data by files in the most efficient way without guarantee of chunksize. - If an `INTEGER` is passed, Wrangler will iterate on the data so that the DataFrames number of rows is equal to INTEGER. + chunked : bool, optional + If `True`, Wrangler returns an Iterable of DataFrames with no guarantee of chunksize. + use_threads : bool + True to enable concurrent requests, False to disable multiple threads. 
+ When enabled, os.cpu_count() is used as the max number of threads. boto3_session : boto3.Session(), optional - Boto3 Session. The default boto3 session will be used if boto3_session receives None. + Boto3 Session. The default boto3 session is used if boto3_session receives None. params: Dict[str, any], optional Dict of parameters used to format the partiQL query. Only named parameters are supported. - The dict must contain the information in the form {'name': 'value'} and the SQL query must contain - `:name;`. + The dict must contain the information in the form {"name": "value"} and the SQL query must contain + `:name`. Returns ------- Union[pd.DataFrame, Iterator[pd.DataFrame]] - Pandas DataFrame or Generator of Pandas DataFrames if chunksize is passed. + Pandas DataFrame or Generator of Pandas DataFrames if chunked is passed. Examples -------- @@ -101,7 +162,7 @@ def read_sql_query( ... sql="SELECT * FROM my_table WHERE name=:name;", ... database="my_db", ... query_as_of_time="1611142914", - ... params={"name": "\'filtered_name\'"} + ... params={"name": "'filtered_name'"} ... ) """ @@ -110,14 +171,13 @@ def read_sql_query( "Please pass only one of `transaction_id` or `query_as_of_time`, not both" ) session: boto3.Session = _utils.ensure_session(session=boto3_session) - chunksize = sys.maxsize if chunksize is True else chunksize + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) if params is None: params = {} for key, value in params.items(): sql = sql.replace(f":{key}", str(value)) - client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) - args: Dict[str, Any] = {"DatabaseName": database, "Statement": sql} + args: Dict[str, Optional[str]] = {"DatabaseName": database, "Statement": sql} if catalog_id: args["CatalogId"] = catalog_id if query_as_of_time: @@ -130,25 +190,4 @@ def read_sql_query( args["TransactionId"] = transaction_id query_id: str = client_lakeformation.plan_query(**args)["QueryId"] - wait_query(query_id=query_id, boto3_session=session) - - scan_kwargs: Dict[str, Union[str, int]] = {"QueryId": query_id, "PageSize": 2} # TODO: Inquire about good page size - next_token: str = "init_token" # Dummy token - token_work_units: List[Tuple[str, int]] = [] - while next_token: - response = client_lakeformation.get_work_units(**scan_kwargs) - token_work_units.extend( # [(Token0, WorkUnitId0), (Token0, WorkUnitId1), (Token1, WorkUnitId0) ... 
]
-            [
-                (unit["Token"], unit_id)
-                for unit in response["Units"]
-                for unit_id in range(unit["WorkUnitIdMin"], unit["WorkUnitIdMax"] + 1)  # Max is inclusive
-            ]
-        )
-        next_token = response["NextToken"]
-        scan_kwargs["NextToken"] = next_token
-
-    dfs: List[pd.DataFrame] = []
-    for token, work_unit in token_work_units:
-        messages: Any = client_lakeformation.execute(QueryId=query_id, Token=token, WorkUnitId=work_unit)["Messages"]
-        dfs.append(RecordBatchStreamReader(messages.read()).read_pandas())
-    return pd.concat(dfs)
+    return _resolve_sql_query(query_id=query_id, chunked=chunked, use_threads=use_threads, boto3_session=boto3_session)
diff --git a/tests/test_lakeformation.py b/tests/test_lakeformation.py
new file mode 100644
index 000000000..783f0df16
--- /dev/null
+++ b/tests/test_lakeformation.py
@@ -0,0 +1,35 @@
+import logging
+
+import pytest
+
+import awswrangler as wr
+
+from ._utils import get_df_csv
+
+logging.getLogger("awswrangler").setLevel(logging.DEBUG)
+
+
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_lakeformation(path, glue_database, glue_table, use_threads):
+    table = f"__{glue_table}"
+    wr.catalog.delete_table_if_exists(database=glue_database, table=table)
+    wr.s3.to_parquet(
+        df=get_df_csv(),
+        path=path,
+        index=False,
+        boto3_session=None,
+        s3_additional_kwargs=None,
+        dataset=True,
+        table=table,
+        database=glue_database,
+        partition_cols=["par0", "par1"],
+        mode="overwrite",
+    )
+    df = wr.lakeformation.read_sql_query(
+        sql=f"SELECT * FROM {table} WHERE id = :id;",
+        database=glue_database,
+        use_threads=use_threads,
+        params={"id": 1},
+    )
+    assert len(df.index) == 1
+    wr.catalog.delete_table_if_exists(database=glue_database, table=table)

From 8a501c9afea374369a31f7c19084b54b66126153 Mon Sep 17 00:00:00 2001
From: Abdel Jaidi
Date: Sun, 31 Jan 2021 14:31:57 +0000
Subject: [PATCH 07/25] Minor - Removing unnecessary ensure_session

---
 awswrangler/lakeformation/_read.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/awswrangler/lakeformation/_read.py b/awswrangler/lakeformation/_read.py
index df9e9f4d7..09380cbdd 100644
--- a/awswrangler/lakeformation/_read.py
+++ 
b/awswrangler/lakeformation/_read.py @@ -20,8 +20,7 @@ def _execute_query( token_work_unit: Tuple[str, int], boto3_session: Optional[boto3.Session] = None, ) -> pd.DataFrame: - session: boto3.Session = _utils.ensure_session(session=boto3_session) - client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=boto3_session) token, work_unit = token_work_unit messages: NativeFile = client_lakeformation.execute(QueryId=query_id, Token=token, WorkUnitId=work_unit)["Messages"] return RecordBatchStreamReader(messages.read()).read_pandas() @@ -33,10 +32,9 @@ def _resolve_sql_query( use_threads: bool = True, boto3_session: Optional[boto3.Session] = None, ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: - session: boto3.Session = _utils.ensure_session(session=boto3_session) - client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=boto3_session) - wait_query(query_id=query_id, boto3_session=session) + wait_query(query_id=query_id, boto3_session=boto3_session) # The LF Query Engine distributes the load across workers # Retrieve the tokens and their associated work units until NextToken is '' @@ -75,6 +73,7 @@ def _resolve_sql_query( itertools.repeat(_utils.boto3_to_primitives(boto3_session=boto3_session)), ) ) + dfs = [df for df in dfs if not df.empty] if not chunked: return pd.concat(dfs) return dfs From f3015b2308d38829f062244b709bf61bbc96de68 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Mon, 1 Feb 2021 12:09:19 +0000 Subject: [PATCH 08/25] Minor - Adding changes from comments and review --- awswrangler/lakeformation/_read.py | 68 +++++++++++++++++++++++++----- tests/test_lakeformation.py | 19 ++++++--- 2 files changed, 72 insertions(+), 15 deletions(-) diff --git a/awswrangler/lakeformation/_read.py b/awswrangler/lakeformation/_read.py index 09380cbdd..14cf0d2bc 100644 --- a/awswrangler/lakeformation/_read.py +++ b/awswrangler/lakeformation/_read.py @@ -6,9 +6,9 @@ import boto3 import pandas as pd -from pyarrow import NativeFile, RecordBatchStreamReader +from pyarrow import NativeFile, RecordBatchStreamReader, Table -from awswrangler import _utils, exceptions +from awswrangler import _data_types, _utils, exceptions from awswrangler._config import apply_configs from awswrangler.lakeformation._utils import wait_query @@ -18,19 +18,40 @@ def _execute_query( query_id: str, token_work_unit: Tuple[str, int], - boto3_session: Optional[boto3.Session] = None, + categories: Optional[List[str]], + safe: bool, + use_threads: bool, + boto3_session: boto3.Session, ) -> pd.DataFrame: client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=boto3_session) token, work_unit = token_work_unit messages: NativeFile = client_lakeformation.execute(QueryId=query_id, Token=token, WorkUnitId=work_unit)["Messages"] - return RecordBatchStreamReader(messages.read()).read_pandas() + table: Table = RecordBatchStreamReader(messages.read()).read_all() + args: Dict[str, Any] = {} + if table.num_rows > 0: + args = { + "use_threads": use_threads, + "split_blocks": True, + "self_destruct": True, + "integer_object_nulls": False, + "date_as_object": True, + "ignore_metadata": True, + "strings_to_categorical": False, + "categories": categories, + "safe": safe, + "types_mapper": _data_types.pyarrow2pandas_extension, + } + df: pd.DataFrame = 
_utils.ensure_df_is_mutable(df=table.to_pandas(**args)) + return df def _resolve_sql_query( query_id: str, - chunked: Optional[bool] = None, - use_threads: bool = True, - boto3_session: Optional[boto3.Session] = None, + chunked: Optional[bool], + categories: Optional[List[str]], + safe: bool, + use_threads: bool, + boto3_session: boto3.Session, ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=boto3_session) @@ -59,7 +80,14 @@ def _resolve_sql_query( dfs: List[pd.DataFrame] = list() if use_threads is False: dfs = list( - _execute_query(query_id=query_id, token_work_unit=token_work_unit, boto3_session=boto3_session) + _execute_query( + query_id=query_id, + token_work_unit=token_work_unit, + categories=categories, + safe=safe, + use_threads=use_threads, + boto3_session=boto3_session, + ) for token_work_unit in token_work_units ) else: @@ -70,12 +98,15 @@ def _resolve_sql_query( _execute_query, itertools.repeat(query_id), token_work_units, + itertools.repeat(categories), + itertools.repeat(safe), + itertools.repeat(use_threads), itertools.repeat(_utils.boto3_to_primitives(boto3_session=boto3_session)), ) ) dfs = [df for df in dfs if not df.empty] if not chunked: - return pd.concat(dfs) + return pd.concat(dfs, sort=False, copy=False, ignore_index=False) return dfs @@ -87,6 +118,8 @@ def read_sql_query( query_as_of_time: Optional[str] = None, catalog_id: Optional[str] = None, chunked: bool = False, + categories: Optional[List[str]] = None, + safe: bool = True, use_threads: bool = True, boto3_session: Optional[boto3.Session] = None, params: Optional[Dict[str, Any]] = None, @@ -125,6 +158,14 @@ def read_sql_query( If none is provided, the AWS account ID is used by default. chunked : bool, optional If `True`, Wrangler returns an Iterable of DataFrames with no guarantee of chunksize. + categories: Optional[List[str]], optional + List of columns names that should be returned as pandas.Categorical. + Recommended for memory restricted environments. + safe : bool, default True + For certain data types, a cast is needed in order to store the + data in a pandas DataFrame or Series (e.g. timestamps are always + stored as nanoseconds in pandas). This option controls whether it + is a safe cast or not. use_threads : bool True to enable concurrent requests, False to disable multiple threads. When enabled, os.cpu_count() is used as the max number of threads. 
@@ -189,4 +230,11 @@ def read_sql_query( args["TransactionId"] = transaction_id query_id: str = client_lakeformation.plan_query(**args)["QueryId"] - return _resolve_sql_query(query_id=query_id, chunked=chunked, use_threads=use_threads, boto3_session=boto3_session) + return _resolve_sql_query( + query_id=query_id, + chunked=chunked, + categories=categories, + safe=safe, + use_threads=use_threads, + boto3_session=session, + ) diff --git a/tests/test_lakeformation.py b/tests/test_lakeformation.py index 783f0df16..51e564259 100644 --- a/tests/test_lakeformation.py +++ b/tests/test_lakeformation.py @@ -14,22 +14,31 @@ def test_lakeformation(path, glue_database, glue_table, use_threads): table = f"__{glue_table}" wr.catalog.delete_table_if_exists(database=glue_database, table=table) wr.s3.to_parquet( - df=get_df_csv(), + df=get_df_csv()[["id", "date", "timestamp", "par0", "par1"]], path=path, index=False, boto3_session=None, s3_additional_kwargs=None, dataset=True, - table=table, - database=glue_database, partition_cols=["par0", "par1"], mode="overwrite", + table=table, + database=glue_database, ) + df = wr.lakeformation.read_sql_query( - sql=f"SELECT * FROM {table} WHERE id = :id;", + sql=f"SELECT * FROM {table};", database=glue_database, use_threads=use_threads, + ) + assert len(df.index) == 3 + assert len(df.columns) == 5 + assert df["id"].sum() == 6 + + df2 = wr.lakeformation.read_sql_query( + sql=f"SELECT * FROM {table} WHERE id = :id;", + database=glue_database, params={"id": 1}, ) - assert len(df.index) == 1 + assert len(df2.index) == 1 wr.catalog.delete_table_if_exists(database=glue_database, table=table) From e42eb44e792c2f0b7b8fa0d1524dba2fe26c0404 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Sat, 13 Feb 2021 17:05:42 +0000 Subject: [PATCH 09/25] Minor - Adding Abort, Begin, Commit and Extend transactions --- awswrangler/lakeformation/__init__.py | 12 +++++++- awswrangler/lakeformation/_read.py | 42 +++++++++++++++++---------- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/awswrangler/lakeformation/__init__.py b/awswrangler/lakeformation/__init__.py index daf66c386..cdb2f7fec 100644 --- a/awswrangler/lakeformation/__init__.py +++ b/awswrangler/lakeformation/__init__.py @@ -1,9 +1,19 @@ """Amazon Lake Formation Module.""" from awswrangler.lakeformation._read import read_sql_query # noqa -from awswrangler.lakeformation._utils import wait_query # noqa +from awswrangler.lakeformation._utils import ( # noqa + abort_transaction, + begin_transaction, + commit_transaction, + extend_transaction, + wait_query, +) __all__ = [ "read_sql_query", + "abort_transaction", + "begin_transaction", + "commit_transaction", + "extend_transaction", "wait_query", ] diff --git a/awswrangler/lakeformation/_read.py b/awswrangler/lakeformation/_read.py index 14cf0d2bc..79c0bfd08 100644 --- a/awswrangler/lakeformation/_read.py +++ b/awswrangler/lakeformation/_read.py @@ -10,7 +10,8 @@ from awswrangler import _data_types, _utils, exceptions from awswrangler._config import apply_configs -from awswrangler.lakeformation._utils import wait_query +from awswrangler.catalog._utils import _catalog_id +from awswrangler.lakeformation._utils import abort_transaction, begin_transaction, wait_query _logger: logging.Logger = logging.getLogger(__name__) @@ -62,7 +63,7 @@ def _resolve_sql_query( # One Token can span multiple work units # PageSize determines the size of the "Units" array in each call # TODO: Inquire about good page size # pylint: disable=W0511 - scan_kwargs: Dict[str, Union[str, int]] = 
{"QueryId": query_id, "PageSize": 2} + scan_kwargs: Dict[str, Union[str, int]] = {"QueryId": query_id, "PageSize": 10} next_token: str = "init_token" # Dummy token token_work_units: List[Tuple[str, int]] = [] while next_token: @@ -74,7 +75,7 @@ def _resolve_sql_query( for unit_id in range(unit["WorkUnitIdMin"], unit["WorkUnitIdMax"] + 1) # Max is inclusive ] ) - next_token = response["NextToken"] + next_token = response.get("NextToken", None) scan_kwargs["NextToken"] = next_token dfs: List[pd.DataFrame] = list() @@ -126,6 +127,11 @@ def read_sql_query( ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: """Execute PartiQL query on AWS Glue Table (Transaction ID or time travel timestamp). Return Pandas DataFrame. + Note + ---- + ORDER BY operations are not honoured. + i.e. sql="SELECT * FROM my_table ORDER BY my_column" is NOT valid + Note ---- The database must NOT be explicitely defined in the PartiQL statement. @@ -217,24 +223,28 @@ def read_sql_query( for key, value in params.items(): sql = sql.replace(f":{key}", str(value)) - args: Dict[str, Optional[str]] = {"DatabaseName": database, "Statement": sql} - if catalog_id: - args["CatalogId"] = catalog_id + args: Dict[str, Optional[str]] = _catalog_id(catalog_id=catalog_id, **{"DatabaseName": database, "Statement": sql}) if query_as_of_time: args["QueryAsOfTime"] = query_as_of_time elif transaction_id: args["TransactionId"] = transaction_id else: _logger.debug("Neither `transaction_id` nor `query_as_of_time` were specified, beginning transaction") - transaction_id = client_lakeformation.begin_transaction(ReadOnly=True)["TransactionId"] + transaction_id = begin_transaction(read_only=True, boto3_session=session) args["TransactionId"] = transaction_id query_id: str = client_lakeformation.plan_query(**args)["QueryId"] - - return _resolve_sql_query( - query_id=query_id, - chunked=chunked, - categories=categories, - safe=safe, - use_threads=use_threads, - boto3_session=session, - ) + try: + return _resolve_sql_query( + query_id=query_id, + chunked=chunked, + categories=categories, + safe=safe, + use_threads=use_threads, + boto3_session=session, + ) + except Exception as ex: + _logger.debug("Aborting transaction with ID: %s.", transaction_id) + if transaction_id: + abort_transaction(transaction_id=transaction_id, boto3_session=session) + _logger.error(ex) + raise From e7ad4c8f7b35384f460e238538a067b7fcf6fa31 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Sat, 13 Feb 2021 17:20:05 +0000 Subject: [PATCH 10/25] Minor - Adding missing functions --- awswrangler/lakeformation/_utils.py | 110 ++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/awswrangler/lakeformation/_utils.py b/awswrangler/lakeformation/_utils.py index 664b5ea65..5f5fcde2f 100644 --- a/awswrangler/lakeformation/_utils.py +++ b/awswrangler/lakeformation/_utils.py @@ -13,6 +13,116 @@ _logger: logging.Logger = logging.getLogger(__name__) +def abort_transaction(transaction_id: str, boto3_session: Optional[boto3.Session] = None) -> None: + """Abort the specified transaction. Returns exception if the transaction was previously committed. + + Parameters + ---------- + transaction_id : str + The ID of the transaction. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session received None. + + Returns + ------- + None + None. 
+ + Examples + -------- + >>> import awswrangler as wr + >>> wr.lakeformation.abort_transaction(transaction_id="...") + + """ + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + + client_lakeformation.abort_transaction(TransactionId=transaction_id) + + +def begin_transaction(read_only: Optional[bool] = False, boto3_session: Optional[boto3.Session] = None) -> str: + """Start a new transaction and returns its transaction ID. + + Parameters + ---------- + read_only : bool, optional + Indicates that that this transaction should be read only. + Writes made using a read-only transaction ID will be rejected. + Read-only transactions do not need to be committed. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session received None. + + Returns + ------- + str + An opaque identifier for the transaction. + + Examples + -------- + >>> import awswrangler as wr + >>> transaction_id = wr.lakeformation.begin_transaction(read_only=False) + + """ + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + transaction_id: str = client_lakeformation.begin_transaction(ReadOnly=read_only)["TransactionId"] + return transaction_id + + +def commit_transaction(transaction_id: str, boto3_session: Optional[boto3.Session] = None) -> None: + """Commit the specified transaction. Returns exception if the transaction was previously aborted. + + Parameters + ---------- + transaction_id : str + The ID of the transaction. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session received None. + + Returns + ------- + None + None. + + Examples + -------- + >>> import awswrangler as wr + >>> wr.lakeformation.commit_transaction(transaction_id="...") + + """ + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + + client_lakeformation.commit_transaction(TransactionId=transaction_id) + + +def extend_transaction(transaction_id: str, boto3_session: Optional[boto3.Session] = None) -> None: + """Indicate to the service that the specified transaction is still active and should not be aborted. + + Parameters + ---------- + transaction_id : str + The ID of the transaction. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session received None. + + Returns + ------- + None + None. + + Examples + -------- + >>> import awswrangler as wr + >>> wr.lakeformation.extend_transaction(transaction_id="...") + + """ + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + + client_lakeformation.extend_transaction(TransactionId=transaction_id) + + def wait_query(query_id: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, Any]: """Wait for the query to end. 
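The four helpers added in this patch make the transaction lifecycle explicit instead of leaving it buried inside read_sql_query. A hedged sketch of how they are meant to compose (what happens between begin and commit is left as a comment, since the write path only arrives in later patches of this series):

import awswrangler as wr

# Explicit transaction lifecycle built from the functions introduced above.
transaction_id = wr.lakeformation.begin_transaction(read_only=False)
try:
    # ... perform reads/writes against `transaction_id` here ...
    # For long-running work, signal that the transaction is still active:
    wr.lakeformation.extend_transaction(transaction_id=transaction_id)
    wr.lakeformation.commit_transaction(transaction_id=transaction_id)
except Exception:
    # Committing an aborted transaction (or aborting a committed one) raises,
    # so exactly one of commit/abort should be issued per transaction.
    wr.lakeformation.abort_transaction(transaction_id=transaction_id)
    raise

This mirrors the try/except that read_sql_query itself now uses when it begins a read-only transaction on the caller's behalf.
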
From dca75f17b26a1e3db14af22e895bb6535872d481 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Sat, 13 Feb 2021 17:34:44 +0000 Subject: [PATCH 11/25] Minor - Adding missing @property --- awswrangler/_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/awswrangler/_config.py b/awswrangler/_config.py index 9bb9cadbc..f11219d39 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -353,6 +353,7 @@ def lakeformation_endpoint_url(self) -> Optional[str]: def lakeformation_endpoint_url(self, value: Optional[str]) -> None: self._set_config_value(key="lakeformation_endpoint_url", value=value) + @property def botocore_config(self) -> botocore.config.Config: """Property botocore_config.""" return cast(Optional[botocore.config.Config], self["botocore_config"]) From cea72c462dffb305a825e4c4b1cf375c5079f47c Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Sat, 13 Feb 2021 17:44:27 +0000 Subject: [PATCH 12/25] Minor - Disable too many public methods --- awswrangler/_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/awswrangler/_config.py b/awswrangler/_config.py index f11219d39..f3f5fe06d 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -48,7 +48,7 @@ class _ConfigArg(NamedTuple): } -class _Config: # pylint: disable=too-many-instance-attributes +class _Config: # pylint: disable=too-many-instance-attributes,too-many-public-methods """Wrangler's Configuration class.""" def __init__(self) -> None: From e47b06664cd28e08e642f800e9625755bab9f166 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Sun, 14 Feb 2021 18:42:35 +0000 Subject: [PATCH 13/25] Minor - Checkpoint --- awswrangler/catalog/_create.py | 12 +++ awswrangler/catalog/_definitions.py | 6 +- awswrangler/lakeformation/_utils.py | 113 +++++++++++++++++++++++++++- awswrangler/s3/_write.py | 10 ++- awswrangler/s3/_write_dataset.py | 112 ++++++++++++++++++++++++--- awswrangler/s3/_write_parquet.py | 54 ++++++++++++- awswrangler/s3/_write_text.py | 4 + 7 files changed, 293 insertions(+), 18 deletions(-) diff --git a/awswrangler/catalog/_create.py b/awswrangler/catalog/_create.py index 2dcbc6fc7..fe6016605 100644 --- a/awswrangler/catalog/_create.py +++ b/awswrangler/catalog/_create.py @@ -214,6 +214,7 @@ def _create_parquet_table( table: str, path: str, columns_types: Dict[str, str], + table_type: Optional[str], partitions_types: Optional[Dict[str, str]], bucketing_info: Optional[Tuple[List[str], int]], catalog_id: Optional[str], @@ -253,6 +254,7 @@ def _create_parquet_table( table=table, path=path, columns_types=columns_types, + table_type=table_type, partitions_types=partitions_types, bucketing_info=bucketing_info, compression=compression, @@ -286,6 +288,7 @@ def _create_csv_table( table: str, path: str, columns_types: Dict[str, str], + table_type: Optional[str], partitions_types: Optional[Dict[str, str]], bucketing_info: Optional[Tuple[List[str], int]], description: Optional[str], @@ -324,6 +327,7 @@ def _create_csv_table( table=table, path=path, columns_types=columns_types, + table_type=table_type, partitions_types=partitions_types, bucketing_info=bucketing_info, compression=compression, @@ -519,6 +523,7 @@ def create_parquet_table( table: str, path: str, columns_types: Dict[str, str], + table_type: Optional[str] = None, partitions_types: Optional[Dict[str, str]] = None, bucketing_info: Optional[Tuple[List[str], int]] = None, catalog_id: Optional[str] = None, @@ -550,6 +555,8 @@ def create_parquet_table( Amazon S3 path (e.g. s3://bucket/prefix/). 
columns_types: Dict[str, str] Dictionary with keys as column names and values as data types (e.g. {'col0': 'bigint', 'col1': 'double'}). + table_type: str, optional + The type of the Glue Table (EXTERNAL_TABLE, GOVERNED...). Set to EXTERNAL_TABLE if None partitions_types: Dict[str, str], optional Dictionary with keys as partition names and values as data types (e.g. {'col2': 'date'}). bucketing_info: Tuple[List[str], int], optional @@ -627,6 +634,7 @@ def create_parquet_table( table=table, path=path, columns_types=columns_types, + table_type=table_type, partitions_types=partitions_types, bucketing_info=bucketing_info, catalog_id=catalog_id, @@ -653,6 +661,7 @@ def create_csv_table( table: str, path: str, columns_types: Dict[str, str], + table_type: Optional[str] = None, partitions_types: Optional[Dict[str, str]] = None, bucketing_info: Optional[Tuple[List[str], int]] = None, compression: Optional[str] = None, @@ -686,6 +695,8 @@ def create_csv_table( Amazon S3 path (e.g. s3://bucket/prefix/). columns_types: Dict[str, str] Dictionary with keys as column names and values as data types (e.g. {'col0': 'bigint', 'col1': 'double'}). + table_type: str, optional + The type of the Glue Table (EXTERNAL_TABLE, GOVERNED...). Set to EXTERNAL_TABLE if None partitions_types: Dict[str, str], optional Dictionary with keys as partition names and values as data types (e.g. {'col2': 'date'}). bucketing_info: Tuple[List[str], int], optional @@ -767,6 +778,7 @@ def create_csv_table( table=table, path=path, columns_types=columns_types, + table_type=table_type, partitions_types=partitions_types, bucketing_info=bucketing_info, catalog_id=catalog_id, diff --git a/awswrangler/catalog/_definitions.py b/awswrangler/catalog/_definitions.py index 778d428dd..f0d0e7ab2 100644 --- a/awswrangler/catalog/_definitions.py +++ b/awswrangler/catalog/_definitions.py @@ -31,6 +31,7 @@ def _parquet_table_definition( table: str, path: str, columns_types: Dict[str, str], + table_type: Optional[str], partitions_types: Dict[str, str], bucketing_info: Optional[Tuple[List[str], int]], compression: Optional[str], @@ -39,7 +40,7 @@ def _parquet_table_definition( return { "Name": table, "PartitionKeys": [{"Name": cname, "Type": dtype} for cname, dtype in partitions_types.items()], - "TableType": "EXTERNAL_TABLE", + "TableType": "EXTERNAL_TABLE" if table_type is None else table_type, "Parameters": {"classification": "parquet", "compressionType": str(compression).lower(), "typeOfData": "file"}, "StorageDescriptor": { "Columns": [{"Name": cname, "Type": dtype} for cname, dtype in columns_types.items()], @@ -100,6 +101,7 @@ def _csv_table_definition( table: str, path: str, columns_types: Dict[str, str], + table_type: Optional[str], partitions_types: Dict[str, str], bucketing_info: Optional[Tuple[List[str], int]], compression: Optional[str], @@ -120,7 +122,7 @@ def _csv_table_definition( return { "Name": table, "PartitionKeys": [{"Name": cname, "Type": dtype} for cname, dtype in partitions_types.items()], - "TableType": "EXTERNAL_TABLE", + "TableType": "EXTERNAL_TABLE" if table_type is None else table_type, "Parameters": parameters, "StorageDescriptor": { "Columns": [{"Name": cname, "Type": dtype} for cname, dtype in columns_types.items()], diff --git a/awswrangler/lakeformation/_utils.py b/awswrangler/lakeformation/_utils.py index 5f5fcde2f..5cf6418be 100644 --- a/awswrangler/lakeformation/_utils.py +++ b/awswrangler/lakeformation/_utils.py @@ -1,11 +1,13 @@ """Utilities Module for Amazon Lake Formation.""" import logging import time -from 
typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import boto3 from awswrangler import _utils, exceptions +from awswrangler.catalog._utils import _catalog_id +from awswrangler.s3._describe import describe_objects _QUERY_FINAL_STATES: List[str] = ["ERROR", "FINISHED"] _QUERY_WAIT_POLLING_DELAY: float = 2 # SECONDS @@ -13,6 +15,115 @@ _logger: logging.Logger = logging.getLogger(__name__) +def _build_partition_predicate( + partition_cols: List[str], + partitions_types: Dict[str, str], + partitions_values: List[str], +) -> str: + partition_predicates: List[str] = [] + for col, val in zip(partition_cols, partitions_values): + if partitions_types[col].startswith(("tinyint", "smallint", "int", "bigint", "float", "double", "decimal")): + partition_predicates.append(f"{col}={str(val)}") + else: + partition_predicates.append(f"{col}='{str(val)}'") + return " AND ".join(partition_predicates) + + +def _build_table_objects( + paths: List[str], + partitions_values: Dict[str, List[str]], + use_threads: bool, + boto3_session: Optional[boto3.Session], +) -> List[Union[str, int, List[Any]]]: + table_objects: List[Union[str, int, List[Any]]] = [] + paths_desc: Dict[str, Dict[str, Any]] = describe_objects( + path=paths, use_threads=use_threads, boto3_session=boto3_session + ) + for path, path_desc in paths_desc.items(): + table_object: Dict[str, Any] = { + "Uri": path, + "ETag": path_desc["ETag"], + "Size": path_desc["ContentLength"], + } + if partitions_values: + table_object["PartitionValues"] = partitions_values[path.rsplit("/", 1)[0]] + table_objects.append(table_object) + return table_objects + + +def _get_table_objects( + catalog_id: Optional[str], + database: str, + table: str, + transaction_id: str, + partition_cols: Optional[List[str]], + partitions_types: Optional[Dict[str, str]], + partitions_values: Optional[List[str]], + boto3_session: Optional[boto3.Session], +) -> List[Union[str, int, List[Any]]]: + """Get Governed Table Objects from Lake Formation Engine""" + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) + + scan_kwargs: Dict[str, Union[str, int]] = _catalog_id( + catalog_id=catalog_id, + **{ + "TransactionId": transaction_id, + "DatabaseName": database, + "TableName": table, + "MaxResults": 100, + }, + ) + if partition_cols: + scan_kwargs["PartitionPredicate"] = _build_partition_predicate( + partition_cols=partition_cols, partitions_types=partitions_types, partitions_values=partitions_values + ) + + next_token: str = "init_token" # Dummy token + table_objects: List[Union[str, int, List[Any]]] = [] + while next_token: + response = client_lakeformation.get_table_objects(**scan_kwargs) + for objects in response["Objects"]: + for table_object in objects["Objects"]: + table_object["PartitionValues"] = objects["PartitionValues"] + table_objects.append(table_object) + next_token = response.get("NextToken", None) + scan_kwargs["NextToken"] = next_token + return table_objects + + +def _update_table_objects( + catalog_id: Optional[str], + database: str, + table: str, + transaction_id: str, + boto3_session: Optional[boto3.Session], + add_objects: Optional[List[Union[str, int, List[Any]]]] = None, + del_objects: Optional[List[Union[str, int, List[Any]]]] = None, +) -> None: + """Register Governed Table Objects changes Lake Formation Engine""" + session: boto3.Session = _utils.ensure_session(session=boto3_session) + client_lakeformation: 
boto3.client = _utils.client(service_name="lakeformation", session=session) + + update_kwargs: Dict[str, Union[str, int]] = _catalog_id( + catalog_id=catalog_id, + **{ + "TransactionId": transaction_id, + "DatabaseName": database, + "TableName": table, + }, + ) + + write_operations: List[Dict[Dict[Any]]] = [] + if add_objects: + write_operations.append({"AddObject": obj for obj in add_objects}) + elif del_objects: + write_operations.append({"DeleteObject": obj for obj in del_objects}) + update_kwargs["WriteOperations"] = write_operations + + client_lakeformation.update_table_objects(**update_kwargs) + + def abort_transaction(transaction_id: str, boto3_session: Optional[boto3.Session] = None) -> None: """Abort the specified transaction. Returns exception if the transaction was previously committed. diff --git a/awswrangler/s3/_write.py b/awswrangler/s3/_write.py index e94a71288..75d3c4ce7 100644 --- a/awswrangler/s3/_write.py +++ b/awswrangler/s3/_write.py @@ -48,6 +48,8 @@ def _validate_args( database: Optional[str], dataset: bool, path: str, + table_type: Optional[str], + transaction_id: Optional[str], partition_cols: Optional[List[str]], bucketing_info: Optional[Tuple[List[str], int]], mode: Optional[str], @@ -58,7 +60,9 @@ def _validate_args( if df.empty is True: raise exceptions.EmptyDataFrame() if dataset is False: - if path.endswith("/"): + if path is None: + raise exceptions.InvalidArgumentValue("If dataset is False, the argument `path` must be passed.") + elif path.endswith("/"): raise exceptions.InvalidArgumentValue( "If , the argument should be a file path, not a directory." ) @@ -79,6 +83,10 @@ def _validate_args( "Arguments database and table must be passed together. If you want to store your dataset metadata in " "the Glue Catalog, please ensure you are passing both." ) + elif (table_type != "GOVERNED") and (transaction_id is not None): + raise exceptions.InvalidArgumentCombination( + "When passing a `transaction_id` as an argument, `table_type` must be set to 'GOVERNED'" + ) elif bucketing_info and bucketing_info[1] <= 0: raise exceptions.InvalidArgumentValue( "Please pass a value greater than 1 for the number of buckets for bucketing." 
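This patch threads table_type from the public catalog functions down to the Glue table definitions, adds the Lake Formation object-registration helpers above, and tightens _validate_args so that a transaction_id is only accepted together with table_type='GOVERNED'. A minimal sketch of the catalog-level entry point (bucket, database, table and column names are placeholders):

import awswrangler as wr

# Declare a Governed Glue table up front; when table_type is omitted the
# definition falls back to EXTERNAL_TABLE, as the definitions module shows.
wr.catalog.create_parquet_table(
    database="my_db",
    table="my_governed_table",
    path="s3://my-bucket/my_governed_table/",
    columns_types={"id": "bigint", "name": "string"},
    partitions_types={"year": "int"},
    table_type="GOVERNED",
)
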
diff --git a/awswrangler/s3/_write_dataset.py b/awswrangler/s3/_write_dataset.py index bf3a7a1f4..5d3b0e548 100644 --- a/awswrangler/s3/_write_dataset.py +++ b/awswrangler/s3/_write_dataset.py @@ -9,6 +9,14 @@ import pandas as pd from awswrangler import exceptions +from awswrangler.lakeformation._utils import ( + _build_table_objects, + _get_table_objects, + _update_table_objects, + abort_transaction, + begin_transaction, + commit_transaction, +) from awswrangler.s3._delete import delete_objects from awswrangler.s3._write_concurrent import _WriteProxy @@ -23,6 +31,12 @@ def _to_partitions( use_threads: bool, mode: str, partition_cols: List[str], + partitions_types: Optional[List[str]], + catalog_id: Optional[str], + database: Optional[str], + table: Optional[str], + table_type: Optional[str], + transaction_id: Optional[str], bucketing_info: Optional[Tuple[List[str], int]], boto3_session: boto3.Session, **func_kwargs: Any, @@ -37,12 +51,32 @@ def _to_partitions( subdir = "/".join([f"{name}={val}" for name, val in zip(partition_cols, keys)]) prefix: str = f"{path_root}{subdir}/" if mode == "overwrite_partitions": - delete_objects( - path=prefix, - use_threads=use_threads, - boto3_session=boto3_session, - s3_additional_kwargs=func_kwargs.get("s3_additional_kwargs"), - ) + if table_type == "GOVERNED": + del_objects: List[Union[str, int, List[Any]]] = _get_table_objects( + catalog_id=catalog_id, + database=database, + table=table, + transaction_id=transaction_id, + partition_cols=partition_cols, + partitions_values=keys, + partitions_types=partitions_types, + boto3_session=boto3_session, + ) + _update_table_objects( + catalog_id=catalog_id, + database=database, + table=table, + transaction_id=transaction_id, + del_objects=del_objects, + boto3_session=boto3_session, + ) + else: + delete_objects( + path=prefix, + use_threads=use_threads, + boto3_session=boto3_session, + s3_additional_kwargs=func_kwargs.get("s3_additional_kwargs"), + ) if bucketing_info: _to_buckets( func=func, @@ -137,24 +171,50 @@ def _to_dataset( use_threads: bool, mode: str, partition_cols: Optional[List[str]], + partitions_types: Optional[List[str]], + catalog_id: Optional[str], + database: Optional[str], + table: Optional[str], + table_type: Optional[str], + transaction_id: Optional[str], bucketing_info: Optional[Tuple[List[str], int]], boto3_session: boto3.Session, **func_kwargs: Any, ) -> Tuple[List[str], Dict[str, List[str]]]: path_root = path_root if path_root.endswith("/") else f"{path_root}/" + commit_trans: bool = False + if table_type == "GOVERNED": + # Check whether to skip committing the transaction (i.e. multiple read/write operations) + if transaction_id is None: + _logger.debug("`transaction_id` not specified, beginning transaction") + transaction_id = begin_transaction(read_only=False, boto3_session=boto3_session) + commit_trans = True + # Evaluate mode if mode not in ["append", "overwrite", "overwrite_partitions"]: raise exceptions.InvalidArgumentValue( f"{mode} is a invalid mode, please use append, overwrite or overwrite_partitions." 
) if (mode == "overwrite") or ((mode == "overwrite_partitions") and (not partition_cols)): - delete_objects( - path=path_root, - use_threads=use_threads, - boto3_session=boto3_session, - s3_additional_kwargs=func_kwargs.get("s3_additional_kwargs"), - ) + if table_type == "GOVERNED": + del_objects: List[Union[str, int, List[Any]]] = _get_table_objects( + catalog_id=catalog_id, + database=database, + table=table, + transaction_id=transaction_id, + boto3_session=boto3_session, + ) + _update_table_objects( + catalog_id=catalog_id, + database=database, + table=table, + transaction_id=transaction_id, + del_objects=del_objects, + boto3_session=boto3_session, + ) + else: + delete_objects(path=path_root, use_threads=use_threads, boto3_session=boto3_session) # Writing partitions_values: Dict[str, List[str]] = {} @@ -167,8 +227,14 @@ def _to_dataset( path_root=path_root, use_threads=use_threads, mode=mode, + catalog_id=catalog_id, + database=database, + table=table, + table_type=table_type, + transaction_id=transaction_id, bucketing_info=bucketing_info, partition_cols=partition_cols, + partitions_types=partitions_types, boto3_session=boto3_session, index=index, **func_kwargs, @@ -190,4 +256,26 @@ def _to_dataset( ) _logger.debug("paths: %s", paths) _logger.debug("partitions_values: %s", partitions_values) + if table_type == "GOVERNED": + add_objects: List[Union[str, int, List[Any]]] = _build_table_objects( + paths, partitions_values, use_threads=use_threads, boto3_session=boto3_session + ) + try: + _update_table_objects( + catalog_id=catalog_id, + database=database, + table=table, + transaction_id=transaction_id, + add_objects=add_objects, + boto3_session=boto3_session, + ) + if commit_trans: + commit_transaction(transaction_id=transaction_id, boto3_session=boto3_session) + except Exception as ex: + _logger.debug("Aborting transaction with ID: %s.", transaction_id) + if transaction_id: + abort_transaction(transaction_id=transaction_id, boto3_session=boto3_session) + _logger.error(ex) + raise + return paths, partitions_values diff --git a/awswrangler/s3/_write_parquet.py b/awswrangler/s3/_write_parquet.py index 5e9311c2b..5c4e2b95a 100644 --- a/awswrangler/s3/_write_parquet.py +++ b/awswrangler/s3/_write_parquet.py @@ -198,7 +198,7 @@ def _to_parquet( @apply_configs def to_parquet( # pylint: disable=too-many-arguments,too-many-locals df: pd.DataFrame, - path: str, + path: Optional[str] = None, index: bool = False, compression: Optional[str] = "snappy", max_rows_by_file: Optional[int] = None, @@ -215,6 +215,8 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals schema_evolution: bool = True, database: Optional[str] = None, table: Optional[str] = None, + table_type: Optional[str] = None, + transaction_id: Optional[str] = None, dtype: Optional[Dict[str, str]] = None, description: Optional[str] = None, parameters: Optional[Dict[str, str]] = None, @@ -306,6 +308,10 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals Glue/Athena catalog: Database name. table : str, optional Glue/Athena catalog: Table name. + table_type: str, optional + The type of the Glue Table. Set to EXTERNAL_TABLE if None. + transaction_id: str, optional + The ID of the transaction when writing to a Governed Table. dtype : Dict[str, str], optional Dictionary of columns names and Athena/Glue types to be casted. Useful when you have columns with undetermined or mixed data types. 
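The _write_dataset changes above decide who owns the transaction: when to_parquet receives no transaction_id, _to_dataset begins one, registers the written objects and commits; a caller-supplied ID is reused and left open for the caller to commit. A hedged single-shot sketch (bucket and names are placeholders):

import awswrangler as wr
import pandas as pd

# Single-shot write to a Governed Table: with no transaction_id passed,
# the dataset writer begins its own transaction, registers the new objects
# with Lake Formation (instead of deleting S3 prefixes) and commits.
wr.s3.to_parquet(
    df=pd.DataFrame({"id": [1, 2, 3], "year": [2021, 2021, 2021]}),
    path="s3://my-bucket/my_governed_table/",
    dataset=True,
    mode="overwrite_partitions",
    partition_cols=["year"],
    database="my_db",
    table="my_governed_table",
    table_type="GOVERNED",
)
# Passing transaction_id=... instead skips the automatic commit, so several
# writes and reads can share one caller-managed transaction.
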
@@ -479,6 +485,8 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals database=database, dataset=dataset, path=path, + table_type=table_type, + transaction_id=transaction_id, partition_cols=partition_cols, bucketing_info=bucketing_info, mode=mode, @@ -510,6 +518,13 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access database=database, table=table, boto3_session=session, catalog_id=catalog_id ) + if path is None: + if catalog_table_input is not None: + path = catalog_table_input["StorageDescriptor"]["Location"] + else: + raise exceptions.InvalidArgumentValue( + "Glue table does not exist. Please pass the `path` argument to create it." + ) df = _apply_dtype(df=df, dtype=dtype, catalog_table_input=catalog_table_input, mode=mode) schema: pa.Schema = _data_types.pyarrow_schema_from_pandas( df=df, index=index, ignore_cols=partition_cols, dtype=dtype @@ -540,6 +555,34 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals ) if schema_evolution is False: _check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode) + if (catalog_table_input is None) and (table_type == "GOVERNED"): + catalog._create_parquet_table( + database=database, + table=table, + path=path, + columns_types=columns_types, + table_type=table_type, + partitions_types=partitions_types, + bucketing_info=bucketing_info, + compression=compression, + description=description, + parameters=parameters, + columns_comments=columns_comments, + boto3_session=session, + mode=mode, + catalog_versioning=catalog_versioning, + projection_enabled=projection_enabled, + projection_types=projection_types, + projection_ranges=projection_ranges, + projection_values=projection_values, + projection_intervals=projection_intervals, + projection_digits=projection_digits, + catalog_id=catalog_id, + catalog_table_input=catalog_table_input, + ) + catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access + database=database, table=table, boto3_session=session, catalog_id=catalog_id + ) paths, partitions_values = _to_dataset( func=_to_parquet, concurrent_partitioning=concurrent_partitioning, @@ -548,9 +591,15 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals index=index, compression=compression, compression_ext=compression_ext, + catalog_id=catalog_id, + database=database, + table=table, + table_type=table_type, + transaction_id=transaction_id, cpus=cpus, use_threads=use_threads, partition_cols=partition_cols, + partitions_types=partitions_types, bucketing_info=bucketing_info, dtype=dtype, mode=mode, @@ -566,6 +615,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals table=table, path=path, columns_types=columns_types, + table_type=table_type, partitions_types=partitions_types, bucketing_info=bucketing_info, compression=compression, @@ -584,7 +634,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals catalog_id=catalog_id, catalog_table_input=catalog_table_input, ) - if partitions_values and (regular_partitions is True): + if partitions_values and (regular_partitions is True) and (table_type != "GOVERNED"): _logger.debug("partitions_values:\n%s", partitions_values) catalog.add_parquet_partitions( database=database, diff --git a/awswrangler/s3/_write_text.py b/awswrangler/s3/_write_text.py index d8e8d2adb..e793d891d 100644 --- a/awswrangler/s3/_write_text.py +++ b/awswrangler/s3/_write_text.py @@ -90,6 +90,7 
@@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state catalog_versioning: bool = False, database: Optional[str] = None, table: Optional[str] = None, + table_type: Optional[str] = None, dtype: Optional[Dict[str, str]] = None, description: Optional[str] = None, parameters: Optional[Dict[str, str]] = None, @@ -183,6 +184,8 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state Glue/Athena catalog: Database name. table : str, optional Glue/Athena catalog: Table name. + table_type: str, optional + The type of the Glue Table. Set to EXTERNAL_TABLE if None dtype : Dict[str, str], optional Dictionary of columns names and Athena/Glue types to be casted. Useful when you have columns with undetermined or mixed data types. @@ -487,6 +490,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state table=table, path=path, columns_types=columns_types, + table_type=table_type, partitions_types=partitions_types, bucketing_info=bucketing_info, description=description, From 4970a3931345b97eff4d1a6720c9eeb3cd5307b7 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Mon, 15 Feb 2021 18:10:19 +0000 Subject: [PATCH 14/25] Major - Governed tables write operations tested --- awswrangler/catalog/_create.py | 6 +- awswrangler/lakeformation/__init__.py | 3 +- awswrangler/lakeformation/_read.py | 109 +++++++++++++++++++++++++- awswrangler/lakeformation/_utils.py | 41 +++++----- awswrangler/s3/_write.py | 4 +- awswrangler/s3/_write_dataset.py | 75 +++++++++--------- awswrangler/s3/_write_parquet.py | 32 ++++++-- awswrangler/s3/_write_text.py | 5 ++ tests/test__routines.py | 68 ++++++++++++---- tests/test_lakeformation.py | 16 ++-- 10 files changed, 270 insertions(+), 89 deletions(-) diff --git a/awswrangler/catalog/_create.py b/awswrangler/catalog/_create.py index fe6016605..eddd30daf 100644 --- a/awswrangler/catalog/_create.py +++ b/awswrangler/catalog/_create.py @@ -33,6 +33,7 @@ def _create_table( # pylint: disable=too-many-branches,too-many-statements catalog_versioning: bool, boto3_session: Optional[boto3.Session], table_input: Dict[str, Any], + table_type: Optional[str], table_exist: bool, projection_enabled: bool, partitions_types: Optional[Dict[str, str]], @@ -118,7 +119,8 @@ def _create_table( # pylint: disable=too-many-branches,too-many-statements f"{mode} is not a valid mode. It must be 'overwrite', 'append' or 'overwrite_partitions'." 
) if table_exist is True and mode == "overwrite": - delete_all_partitions(table=table, database=database, catalog_id=catalog_id, boto3_session=session) + if table_type != "GOVERNED": + delete_all_partitions(table=table, database=database, catalog_id=catalog_id, boto3_session=session) _logger.debug("Updating table (%s)...", mode) client_glue.update_table( **_catalog_id( @@ -271,6 +273,7 @@ def _create_parquet_table( catalog_versioning=catalog_versioning, boto3_session=boto3_session, table_input=table_input, + table_type=table_type, table_exist=table_exist, partitions_types=partitions_types, projection_enabled=projection_enabled, @@ -346,6 +349,7 @@ def _create_csv_table( catalog_versioning=catalog_versioning, boto3_session=boto3_session, table_input=table_input, + table_type=table_type, table_exist=table_exist, partitions_types=partitions_types, projection_enabled=projection_enabled, diff --git a/awswrangler/lakeformation/__init__.py b/awswrangler/lakeformation/__init__.py index cdb2f7fec..8b8c3084e 100644 --- a/awswrangler/lakeformation/__init__.py +++ b/awswrangler/lakeformation/__init__.py @@ -1,6 +1,6 @@ """Amazon Lake Formation Module.""" -from awswrangler.lakeformation._read import read_sql_query # noqa +from awswrangler.lakeformation._read import read_sql_query, read_sql_table # noqa from awswrangler.lakeformation._utils import ( # noqa abort_transaction, begin_transaction, @@ -11,6 +11,7 @@ __all__ = [ "read_sql_query", + "read_sql_table", "abort_transaction", "begin_transaction", "commit_transaction", diff --git a/awswrangler/lakeformation/_read.py b/awswrangler/lakeformation/_read.py index 79c0bfd08..be414d753 100644 --- a/awswrangler/lakeformation/_read.py +++ b/awswrangler/lakeformation/_read.py @@ -8,7 +8,7 @@ import pandas as pd from pyarrow import NativeFile, RecordBatchStreamReader, Table -from awswrangler import _data_types, _utils, exceptions +from awswrangler import _data_types, _utils, catalog, exceptions from awswrangler._config import apply_configs from awswrangler.catalog._utils import _catalog_id from awswrangler.lakeformation._utils import abort_transaction, begin_transaction, wait_query @@ -62,7 +62,6 @@ def _resolve_sql_query( # Retrieve the tokens and their associated work units until NextToken is '' # One Token can span multiple work units # PageSize determines the size of the "Units" array in each call - # TODO: Inquire about good page size # pylint: disable=W0511 scan_kwargs: Dict[str, Union[str, int]] = {"QueryId": query_id, "PageSize": 10} next_token: str = "init_token" # Dummy token token_work_units: List[Tuple[str, int]] = [] @@ -248,3 +247,109 @@ def read_sql_query( abort_transaction(transaction_id=transaction_id, boto3_session=session) _logger.error(ex) raise + + +@apply_configs +def read_sql_table( + table: str, + database: str, + transaction_id: Optional[str] = None, + query_as_of_time: Optional[str] = None, + catalog_id: Optional[str] = None, + chunked: bool = False, + categories: Optional[List[str]] = None, + safe: bool = True, + use_threads: bool = True, + boto3_session: Optional[boto3.Session] = None, +) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: + """Extract all rows from AWS Glue Table (Transaction ID or time travel timestamp). Return Pandas DataFrame. + + Note + ---- + ORDER BY operations are not honoured. + i.e. sql="SELECT * FROM my_table ORDER BY my_column" is NOT valid + + Note + ---- + Pass one of `transaction_id` or `query_as_of_time`, not both. 
+ + Note + ---- + `chunked` argument (memory-friendly): + If set to `True`, return an Iterable of DataFrames instead of a regular DataFrame. + + Parameters + ---------- + table : str + AWS Glue table name. + database : str + AWS Glue database name + transaction_id : str, optional + The ID of the transaction at which to read the table contents. + Cannot be specified alongside query_as_of_time + query_as_of_time : str, optional + The time as of when to read the table contents. Must be a valid Unix epoch timestamp. + Cannot be specified alongside transaction_id + catalog_id : str, optional + The ID of the Data Catalog from which to retrieve Databases. + If none is provided, the AWS account ID is used by default. + chunked : bool, optional + If `True`, Wrangler returns an Iterable of DataFrames with no guarantee of chunksize. + categories: Optional[List[str]], optional + List of columns names that should be returned as pandas.Categorical. + Recommended for memory restricted environments. + safe : bool, default True + For certain data types, a cast is needed in order to store the + data in a pandas DataFrame or Series (e.g. timestamps are always + stored as nanoseconds in pandas). This option controls whether it + is a safe cast or not. + use_threads : bool + True to enable concurrent requests, False to disable multiple threads. + When enabled, os.cpu_count() is used as the max number of threads. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session is used if boto3_session receives None. + + Returns + ------- + Union[pd.DataFrame, Iterator[pd.DataFrame]] + Pandas DataFrame or Generator of Pandas DataFrames if chunked is passed. + + Examples + -------- + >>> import awswrangler as wr + >>> df = wr.lakeformation.read_sql_table( + ... table="my_table", + ... database="my_db", + ... catalog_id="111111111111", + ... ) + + >>> import awswrangler as wr + >>> df = wr.lakeformation.read_sql_table( + ... table="my_table", + ... database="my_db", + ... transaction_id="1b62811fa3e02c4e5fdbaa642b752030379c4a8a70da1f8732ce6ccca47afdc9", + ... chunked=True, + ... ) + + >>> import awswrangler as wr + >>> df = wr.lakeformation.read_sql_table( + ... table="my_table", + ... database="my_db", + ... query_as_of_time="1611142914", + ... use_threads=True, + ... 
) + + """ + table = catalog.sanitize_table_name(table=table) + return read_sql_query( + sql=f"SELECT * FROM {table}", + database=database, + transaction_id=transaction_id, + query_as_of_time=query_as_of_time, + safe=safe, + catalog_id=catalog_id, + categories=categories, + chunked=chunked, + use_threads=use_threads, + boto3_session=boto3_session, + ) diff --git a/awswrangler/lakeformation/_utils.py b/awswrangler/lakeformation/_utils.py index 5cf6418be..5088096b1 100644 --- a/awswrangler/lakeformation/_utils.py +++ b/awswrangler/lakeformation/_utils.py @@ -15,6 +15,10 @@ _logger: logging.Logger = logging.getLogger(__name__) +def _without_keys(d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]: + return {x: d[x] for x in d if x not in keys} + + def _build_partition_predicate( partition_cols: List[str], partitions_types: Dict[str, str], @@ -34,8 +38,8 @@ def _build_table_objects( partitions_values: Dict[str, List[str]], use_threads: bool, boto3_session: Optional[boto3.Session], -) -> List[Union[str, int, List[Any]]]: - table_objects: List[Union[str, int, List[Any]]] = [] +) -> List[Dict[str, Any]]: + table_objects: List[Dict[str, Any]] = [] paths_desc: Dict[str, Dict[str, Any]] = describe_objects( path=paths, use_threads=use_threads, boto3_session=boto3_session ) @@ -46,7 +50,7 @@ def _build_table_objects( "Size": path_desc["ContentLength"], } if partitions_values: - table_object["PartitionValues"] = partitions_values[path.rsplit("/", 1)[0]] + table_object["PartitionValues"] = partitions_values[f"{path.rsplit('/', 1)[0].rstrip('/')}/"] table_objects.append(table_object) return table_objects @@ -56,12 +60,12 @@ def _get_table_objects( database: str, table: str, transaction_id: str, - partition_cols: Optional[List[str]], - partitions_types: Optional[Dict[str, str]], - partitions_values: Optional[List[str]], boto3_session: Optional[boto3.Session], -) -> List[Union[str, int, List[Any]]]: - """Get Governed Table Objects from Lake Formation Engine""" + partition_cols: Optional[List[str]] = None, + partitions_types: Optional[Dict[str, str]] = None, + partitions_values: Optional[List[str]] = None, +) -> List[Dict[str, Any]]: + """Get Governed Table Objects from Lake Formation Engine.""" session: boto3.Session = _utils.ensure_session(session=boto3_session) client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) @@ -74,18 +78,19 @@ def _get_table_objects( "MaxResults": 100, }, ) - if partition_cols: + if partition_cols and partitions_types and partitions_values: scan_kwargs["PartitionPredicate"] = _build_partition_predicate( partition_cols=partition_cols, partitions_types=partitions_types, partitions_values=partitions_values ) next_token: str = "init_token" # Dummy token - table_objects: List[Union[str, int, List[Any]]] = [] + table_objects: List[Dict[str, Any]] = [] while next_token: response = client_lakeformation.get_table_objects(**scan_kwargs) for objects in response["Objects"]: for table_object in objects["Objects"]: - table_object["PartitionValues"] = objects["PartitionValues"] + if objects["PartitionValues"]: + table_object["PartitionValues"] = objects["PartitionValues"] table_objects.append(table_object) next_token = response.get("NextToken", None) scan_kwargs["NextToken"] = next_token @@ -98,14 +103,14 @@ def _update_table_objects( table: str, transaction_id: str, boto3_session: Optional[boto3.Session], - add_objects: Optional[List[Union[str, int, List[Any]]]] = None, - del_objects: Optional[List[Union[str, int, List[Any]]]] = None, + add_objects: 
Optional[List[Dict[str, Any]]] = None, + del_objects: Optional[List[Dict[str, Any]]] = None, ) -> None: - """Register Governed Table Objects changes Lake Formation Engine""" + """Register Governed Table Objects changes to Lake Formation Engine.""" session: boto3.Session = _utils.ensure_session(session=boto3_session) client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) - update_kwargs: Dict[str, Union[str, int]] = _catalog_id( + update_kwargs: Dict[str, Union[str, int, List[Dict[str, Dict[str, Any]]]]] = _catalog_id( catalog_id=catalog_id, **{ "TransactionId": transaction_id, @@ -114,11 +119,11 @@ def _update_table_objects( }, ) - write_operations: List[Dict[Dict[Any]]] = [] + write_operations: List[Dict[str, Dict[str, Any]]] = [] if add_objects: - write_operations.append({"AddObject": obj for obj in add_objects}) + write_operations.extend({"AddObject": obj} for obj in add_objects) elif del_objects: - write_operations.append({"DeleteObject": obj for obj in del_objects}) + write_operations.extend({"DeleteObject": _without_keys(obj, ["Size"])} for obj in del_objects) update_kwargs["WriteOperations"] = write_operations client_lakeformation.update_table_objects(**update_kwargs) diff --git a/awswrangler/s3/_write.py b/awswrangler/s3/_write.py index 75d3c4ce7..0ed48535e 100644 --- a/awswrangler/s3/_write.py +++ b/awswrangler/s3/_write.py @@ -47,7 +47,7 @@ def _validate_args( table: Optional[str], database: Optional[str], dataset: bool, - path: str, + path: Optional[str], table_type: Optional[str], transaction_id: Optional[str], partition_cols: Optional[List[str]], @@ -62,7 +62,7 @@ def _validate_args( if dataset is False: if path is None: raise exceptions.InvalidArgumentValue("If dataset is False, the argument `path` must be passed.") - elif path.endswith("/"): + if path.endswith("/"): raise exceptions.InvalidArgumentValue( "If , the argument should be a file path, not a directory." 
) diff --git a/awswrangler/s3/_write_dataset.py b/awswrangler/s3/_write_dataset.py index 5d3b0e548..3bc05cf2d 100644 --- a/awswrangler/s3/_write_dataset.py +++ b/awswrangler/s3/_write_dataset.py @@ -31,7 +31,7 @@ def _to_partitions( use_threads: bool, mode: str, partition_cols: List[str], - partitions_types: Optional[List[str]], + partitions_types: Optional[Dict[str, str]], catalog_id: Optional[str], database: Optional[str], table: Optional[str], @@ -51,25 +51,26 @@ def _to_partitions( subdir = "/".join([f"{name}={val}" for name, val in zip(partition_cols, keys)]) prefix: str = f"{path_root}{subdir}/" if mode == "overwrite_partitions": - if table_type == "GOVERNED": - del_objects: List[Union[str, int, List[Any]]] = _get_table_objects( + if (table_type == "GOVERNED") and (table is not None) and (database is not None): + del_objects: List[Dict[str, Any]] = _get_table_objects( catalog_id=catalog_id, database=database, table=table, - transaction_id=transaction_id, + transaction_id=transaction_id, # type: ignore partition_cols=partition_cols, partitions_values=keys, partitions_types=partitions_types, boto3_session=boto3_session, ) - _update_table_objects( - catalog_id=catalog_id, - database=database, - table=table, - transaction_id=transaction_id, - del_objects=del_objects, - boto3_session=boto3_session, - ) + if del_objects: + _update_table_objects( + catalog_id=catalog_id, + database=database, + table=table, + transaction_id=transaction_id, # type: ignore + del_objects=del_objects, + boto3_session=boto3_session, + ) else: delete_objects( path=prefix, @@ -171,7 +172,7 @@ def _to_dataset( use_threads: bool, mode: str, partition_cols: Optional[List[str]], - partitions_types: Optional[List[str]], + partitions_types: Optional[Dict[str, str]], catalog_id: Optional[str], database: Optional[str], table: Optional[str], @@ -197,22 +198,23 @@ def _to_dataset( f"{mode} is a invalid mode, please use append, overwrite or overwrite_partitions." 
) if (mode == "overwrite") or ((mode == "overwrite_partitions") and (not partition_cols)): - if table_type == "GOVERNED": - del_objects: List[Union[str, int, List[Any]]] = _get_table_objects( - catalog_id=catalog_id, - database=database, - table=table, - transaction_id=transaction_id, - boto3_session=boto3_session, - ) - _update_table_objects( + if (table_type == "GOVERNED") and (table is not None) and (database is not None): + del_objects: List[Dict[str, Any]] = _get_table_objects( catalog_id=catalog_id, database=database, table=table, - transaction_id=transaction_id, - del_objects=del_objects, + transaction_id=transaction_id, # type: ignore boto3_session=boto3_session, ) + if del_objects: + _update_table_objects( + catalog_id=catalog_id, + database=database, + table=table, + transaction_id=transaction_id, # type: ignore + del_objects=del_objects, + boto3_session=boto3_session, + ) else: delete_objects(path=path_root, use_threads=use_threads, boto3_session=boto3_session) @@ -256,21 +258,22 @@ def _to_dataset( ) _logger.debug("paths: %s", paths) _logger.debug("partitions_values: %s", partitions_values) - if table_type == "GOVERNED": - add_objects: List[Union[str, int, List[Any]]] = _build_table_objects( + if (table_type == "GOVERNED") and (table is not None) and (database is not None): + add_objects: List[Dict[str, Any]] = _build_table_objects( paths, partitions_values, use_threads=use_threads, boto3_session=boto3_session ) try: - _update_table_objects( - catalog_id=catalog_id, - database=database, - table=table, - transaction_id=transaction_id, - add_objects=add_objects, - boto3_session=boto3_session, - ) - if commit_trans: - commit_transaction(transaction_id=transaction_id, boto3_session=boto3_session) + if add_objects: + _update_table_objects( + catalog_id=catalog_id, + database=database, + table=table, + transaction_id=transaction_id, # type: ignore + add_objects=add_objects, + boto3_session=boto3_session, + ) + if commit_trans: + commit_transaction(transaction_id=transaction_id, boto3_session=boto3_session) # type: ignore except Exception as ex: _logger.debug("Aborting transaction with ID: %s.", transaction_id) if transaction_id: diff --git a/awswrangler/s3/_write_parquet.py b/awswrangler/s3/_write_parquet.py index 5c4e2b95a..2148c87f7 100644 --- a/awswrangler/s3/_write_parquet.py +++ b/awswrangler/s3/_write_parquet.py @@ -254,7 +254,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals ---------- df: pandas.DataFrame Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html - path : str + path : str, optional S3 path (for file e.g. ``s3://bucket/prefix/filename.parquet``) (for dataset e.g. ``s3://bucket/prefix``). index : bool True to store the DataFrame index in file, otherwise False to ignore it. @@ -457,6 +457,28 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals } } + Writing dataset to Glue governed table + + >>> import awswrangler as wr + >>> import pandas as pd + >>> wr.s3.to_parquet( + ... df=pd.DataFrame({ + ... 'col': [1, 2, 3], + ... 'col2': ['A', 'A', 'B'], + ... 'col3': [None, None, None] + ... }), + ... dataset=True, + ... mode='overwrite', + ... database='default', # Athena/Glue database + ... table='my_table', # Athena/Glue table + ... table_type='GOVERNED', + ... transaction_id="xxx", + ... 
) + { + 'paths': ['s3://.../x.parquet'], + 'partitions_values: {} + } + Writing dataset casting empty column data type >>> import awswrangler as wr @@ -556,10 +578,10 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals if schema_evolution is False: _check_schema_changes(columns_types=columns_types, table_input=catalog_table_input, mode=mode) if (catalog_table_input is None) and (table_type == "GOVERNED"): - catalog._create_parquet_table( + catalog._create_parquet_table( # pylint: disable=protected-access database=database, table=table, - path=path, + path=path, # type: ignore columns_types=columns_types, table_type=table_type, partitions_types=partitions_types, @@ -587,7 +609,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals func=_to_parquet, concurrent_partitioning=concurrent_partitioning, df=df, - path_root=path, + path_root=path, # type: ignore index=index, compression=compression, compression_ext=compression_ext, @@ -613,7 +635,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals catalog._create_parquet_table( # pylint: disable=protected-access database=database, table=table, - path=path, + path=path, # type: ignore columns_types=columns_types, table_type=table_type, partitions_types=partitions_types, diff --git a/awswrangler/s3/_write_text.py b/awswrangler/s3/_write_text.py index e793d891d..029511cb2 100644 --- a/awswrangler/s3/_write_text.py +++ b/awswrangler/s3/_write_text.py @@ -91,6 +91,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state database: Optional[str] = None, table: Optional[str] = None, table_type: Optional[str] = None, + transaction_id: Optional[str] = None, dtype: Optional[Dict[str, str]] = None, description: Optional[str] = None, parameters: Optional[Dict[str, str]] = None, @@ -186,6 +187,8 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state Glue/Athena catalog: Table name. table_type: str, optional The type of the Glue Table. Set to EXTERNAL_TABLE if None + transaction_id: str, optional + The ID of the transaction when writing to a Governed Table. dtype : Dict[str, str], optional Dictionary of columns names and Athena/Glue types to be casted. Useful when you have columns with undetermined or mixed data types. 
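to_csv gains the same pair of arguments, so a CSV dataset can target a Governed Table as well. The sketch below assumes the text writer is wired through the same _to_dataset path as the parquet writer (names and paths are placeholders), and reads the rows back through the new read_sql_table helper:

import awswrangler as wr
import pandas as pd

# CSV dataset written to a Governed Table; table_type falls back to
# EXTERNAL_TABLE when omitted, per the docstring above.
wr.s3.to_csv(
    df=pd.DataFrame({"c0": ["foo", "bar"], "c1": [0, 1]}),
    path="s3://my-bucket/my_governed_csv_table/",
    dataset=True,
    index=False,
    mode="append",
    database="my_db",
    table="my_governed_csv_table",
    table_type="GOVERNED",
)

# Read the table back through the Lake Formation query engine.
df = wr.lakeformation.read_sql_table(
    table="my_governed_csv_table",
    database="my_db",
    use_threads=True,
)
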
@@ -391,6 +394,8 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state database=database, dataset=dataset, path=path, + table_type=table_type, + transaction_id=transaction_id, partition_cols=partition_cols, bucketing_info=bucketing_info, mode=mode, diff --git a/tests/test__routines.py b/tests/test__routines.py index fb08e8d12..57c9f2526 100644 --- a/tests/test__routines.py +++ b/tests/test__routines.py @@ -10,7 +10,8 @@ @pytest.mark.parametrize("use_threads", [True, False]) @pytest.mark.parametrize("concurrent_partitioning", [True, False]) -def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_partitioning): +@pytest.mark.parametrize("table_type", ["EXTERNAL_TABLE", "GOVERNED"]) +def test_routine_0(glue_database, glue_table, table_type, path, use_threads, concurrent_partitioning): # Round 1 - Warm up df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") @@ -21,6 +22,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part mode="overwrite", database=glue_database, table=glue_table, + table_type=table_type, description="c0", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, columns_comments={"c0": "0"}, @@ -28,7 +30,10 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part concurrent_partitioning=concurrent_partitioning, ) assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() parameters = wr.catalog.get_table_parameters(glue_database, glue_table) @@ -44,11 +49,11 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part df = pd.DataFrame({"c1": [None, 1, None]}, dtype="Int16") wr.s3.to_parquet( df=df, - path=path, dataset=True, mode="overwrite", database=glue_database, table=glue_table, + table_type=table_type, description="c1", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, columns_comments={"c1": "1"}, @@ -56,7 +61,10 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part concurrent_partitioning=concurrent_partitioning, ) assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) assert df.shape == df2.shape assert df.c1.sum() == df2.c1.sum() parameters = wr.catalog.get_table_parameters(glue_database, glue_table) @@ -77,6 +85,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part mode="append", database=glue_database, table=glue_table, + table_type=table_type, description="c1", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index) * 2)}, columns_comments={"c1": "1"}, @@ -84,7 +93,10 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part concurrent_partitioning=concurrent_partitioning, ) assert wr.catalog.get_table_number_of_versions(table=glue_table, 
database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) assert len(df.columns) == len(df2.columns) assert len(df.index) * 2 == len(df2.index) assert df.c1.sum() + 1 == df2.c1.sum() @@ -101,11 +113,11 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part df = pd.DataFrame({"c2": ["a", None, "b"], "c1": [None, None, None]}) wr.s3.to_parquet( df=df, - path=path, dataset=True, mode="append", database=glue_database, table=glue_table, + table_type=table_type, description="c1+c2", parameters={"num_cols": "2", "num_rows": "9"}, columns_comments={"c1": "1", "c2": "2"}, @@ -113,7 +125,10 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part concurrent_partitioning=concurrent_partitioning, ) assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) assert len(df2.columns) == 2 assert len(df2.index) == 9 assert df2.c1.sum() == 3 @@ -136,6 +151,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part mode="append", database=glue_database, table=glue_table, + table_type=table_type, description="c1+c2+c3", parameters={"num_cols": "3", "num_rows": "10"}, columns_comments={"c1": "1!", "c2": "2!", "c3": "3"}, @@ -143,7 +159,10 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part concurrent_partitioning=concurrent_partitioning, ) assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) assert len(df2.columns) == 3 assert len(df2.index) == 10 assert df2.c1.sum() == 4 @@ -158,7 +177,13 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part assert comments["c2"] == "2!" 
assert comments["c3"] == "3" - # Round 6 - Overwrite Partitioned + +@pytest.mark.parametrize("use_threads", [True, False]) +@pytest.mark.parametrize("concurrent_partitioning", [True, False]) +@pytest.mark.parametrize("table_type", ["EXTERNAL_TABLE", "GOVERNED"]) +def test_routine_1(glue_database, glue_table, table_type, path, use_threads, concurrent_partitioning): + + # Round 1 - Overwrite Partitioned df = pd.DataFrame({"c0": ["foo", None], "c1": [0, 1]}) wr.s3.to_parquet( df=df, @@ -167,6 +192,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part mode="overwrite", database=glue_database, table=glue_table, + table_type=table_type, partition_cols=["c1"], description="c0+c1", parameters={"num_cols": "2", "num_rows": "2"}, @@ -175,7 +201,10 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part concurrent_partitioning=concurrent_partitioning, ) assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) assert df.shape == df2.shape assert df.c1.sum() == df2.c1.sum() parameters = wr.catalog.get_table_parameters(glue_database, glue_table) @@ -188,15 +217,15 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part assert comments["c0"] == "zero" assert comments["c1"] == "one" - # Round 7 - Overwrite Partitions + # Round 2 - Overwrite Partitions df = pd.DataFrame({"c0": [None, None], "c1": [0, 2]}) wr.s3.to_parquet( df=df, - path=path, dataset=True, mode="overwrite_partitions", database=glue_database, table=glue_table, + table_type=table_type, partition_cols=["c1"], description="c0+c1", parameters={"num_cols": "2", "num_rows": "3"}, @@ -205,7 +234,10 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part use_threads=use_threads, ) assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) assert len(df2.columns) == 2 assert len(df2.index) == 3 assert df2.c1.sum() == 3 @@ -219,7 +251,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part assert comments["c0"] == "zero" assert comments["c1"] == "one" - # Round 8 - Overwrite Partitions + New Column + Wrong Type + # Round 3 - Overwrite Partitions + New Column + Wrong Type df = pd.DataFrame({"c0": [1, 2], "c1": ["1", "3"], "c2": [True, False]}) wr.s3.to_parquet( df=df, @@ -228,6 +260,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part mode="overwrite_partitions", database=glue_database, table=glue_table, + table_type=table_type, partition_cols=["c1"], description="c0+c1+c2", parameters={"num_cols": "3", "num_rows": "4"}, @@ -236,7 +269,10 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part concurrent_partitioning=concurrent_partitioning, ) assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 - df2 = wr.athena.read_sql_table(glue_table, 
glue_database, use_threads=use_threads) + if table_type == "GOVERNED": + df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + else: + df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) assert len(df2.columns) == 3 assert len(df2.index) == 4 assert df2.c1.sum() == 6 @@ -252,7 +288,7 @@ def test_routine_0(glue_database, glue_table, path, use_threads, concurrent_part assert comments["c2"] == "two" -def test_routine_1(glue_database, glue_table, path): +def test_routine_2(glue_database, glue_table, path): # Round 1 - Warm up df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") diff --git a/tests/test_lakeformation.py b/tests/test_lakeformation.py index 51e564259..a36759de8 100644 --- a/tests/test_lakeformation.py +++ b/tests/test_lakeformation.py @@ -4,7 +4,7 @@ import awswrangler as wr -from ._utils import get_df_csv +from ._utils import get_df logging.getLogger("awswrangler").setLevel(logging.DEBUG) @@ -14,7 +14,7 @@ def test_lakeformation(path, glue_database, glue_table, use_threads): table = f"__{glue_table}" wr.catalog.delete_table_if_exists(database=glue_database, table=table) wr.s3.to_parquet( - df=get_df_csv()[["id", "date", "timestamp", "par0", "par1"]], + df=get_df().drop(["iint8", "binary"], axis=1), # tinyint & binary currently not supported path=path, index=False, boto3_session=None, @@ -26,19 +26,19 @@ def test_lakeformation(path, glue_database, glue_table, use_threads): database=glue_database, ) - df = wr.lakeformation.read_sql_query( - sql=f"SELECT * FROM {table};", + df = wr.lakeformation.read_sql_table( + table=table, database=glue_database, use_threads=use_threads, ) assert len(df.index) == 3 - assert len(df.columns) == 5 - assert df["id"].sum() == 6 + assert len(df.columns) == 14 + assert df["iint32"].sum() == 3 df2 = wr.lakeformation.read_sql_query( - sql=f"SELECT * FROM {table} WHERE id = :id;", + sql=f"SELECT * FROM {table} WHERE iint16 = :iint16;", database=glue_database, - params={"id": 1}, + params={"iint16": 1}, ) assert len(df2.index) == 1 wr.catalog.delete_table_if_exists(database=glue_database, table=table) From be33b0d488b94416378e95a3afbc3ec992c6a145 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Mon, 15 Feb 2021 19:05:07 +0000 Subject: [PATCH 15/25] Minor - Adding validate flow on branches --- .github/workflows/static-checking.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/static-checking.yml b/.github/workflows/static-checking.yml index fafca0e93..536231d38 100644 --- a/.github/workflows/static-checking.yml +++ b/.github/workflows/static-checking.yml @@ -5,10 +5,12 @@ on: branches: - main - main-governed-tables + - feature/lf-write-table pull_request: branches: - main - main-governed-tables + - feature/lf-transactions jobs: Check: From 424048eae40fe51c309f298067b548d44982e454 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Mon, 15 Feb 2021 19:09:15 +0000 Subject: [PATCH 16/25] Minor - reducing static checks --- .github/workflows/static-checking.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/static-checking.yml b/.github/workflows/static-checking.yml index 536231d38..f4622631a 100644 --- a/.github/workflows/static-checking.yml +++ b/.github/workflows/static-checking.yml @@ -5,7 +5,6 @@ on: branches: - main - main-governed-tables - - feature/lf-write-table pull_request: branches: - main From e6eca7ba96a2d0789c8ee98f3cc303895fe7b7ed Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Sat, 20 Feb 2021 17:14:08 +0000 Subject: [PATCH 17/25] Minor - 
Adding to_csv code --- awswrangler/catalog/_create.py | 2 +- awswrangler/catalog/_definitions.py | 2 +- awswrangler/s3/_write_parquet.py | 2 +- awswrangler/s3/_write_text.py | 85 ++++++++++++++++++++++++++--- 4 files changed, 80 insertions(+), 11 deletions(-) diff --git a/awswrangler/catalog/_create.py b/awswrangler/catalog/_create.py index eddd30daf..50f7f82d0 100644 --- a/awswrangler/catalog/_create.py +++ b/awswrangler/catalog/_create.py @@ -289,7 +289,7 @@ def _create_parquet_table( def _create_csv_table( database: str, table: str, - path: str, + path: Optional[str], columns_types: Dict[str, str], table_type: Optional[str], partitions_types: Optional[Dict[str, str]], diff --git a/awswrangler/catalog/_definitions.py b/awswrangler/catalog/_definitions.py index f0d0e7ab2..97aea2eac 100644 --- a/awswrangler/catalog/_definitions.py +++ b/awswrangler/catalog/_definitions.py @@ -99,7 +99,7 @@ def _parquet_partition_definition( def _csv_table_definition( table: str, - path: str, + path: Optional[str], columns_types: Dict[str, str], table_type: Optional[str], partitions_types: Dict[str, str], diff --git a/awswrangler/s3/_write_parquet.py b/awswrangler/s3/_write_parquet.py index 2148c87f7..a5fecd1e3 100644 --- a/awswrangler/s3/_write_parquet.py +++ b/awswrangler/s3/_write_parquet.py @@ -468,7 +468,7 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals ... 'col3': [None, None, None] ... }), ... dataset=True, - ... mode='overwrite', + ... mode='append', ... database='default', # Athena/Glue database ... table='my_table', # Athena/Glue table ... table_type='GOVERNED', diff --git a/awswrangler/s3/_write_text.py b/awswrangler/s3/_write_text.py index 029511cb2..e62e053af 100644 --- a/awswrangler/s3/_write_text.py +++ b/awswrangler/s3/_write_text.py @@ -74,7 +74,7 @@ def _to_text( @apply_configs def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements df: pd.DataFrame, - path: str, + path: Optional[str] = None, sep: str = ",", index: bool = True, columns: Optional[List[str]] = None, @@ -139,7 +139,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state ---------- df: pandas.DataFrame Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html - path : str + path : str, optional Amazon S3 path (e.g. s3://bucket/filename.csv). sep : str String of length 1. Field delimiter for the output file. @@ -355,6 +355,28 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state } } + Writing dataset to Glue governed table + + >>> import awswrangler as wr + >>> import pandas as pd + >>> wr.s3.to_csv( + ... df=pd.DataFrame({ + ... 'col': [1, 2, 3], + ... 'col2': ['A', 'A', 'B'], + ... 'col3': [None, None, None] + ... }), + ... dataset=True, + ... mode='append', + ... database='default', # Athena/Glue database + ... table='my_table', # Athena/Glue table + ... table_type='GOVERNED', + ... transaction_id="xxx", + ... 
) + { + 'paths': ['s3://.../x.csv'], + 'partitions_values: {} + } + Writing dataset casting empty column data type >>> import awswrangler as wr @@ -421,6 +443,13 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access database=database, table=table, boto3_session=session, catalog_id=catalog_id ) + if path is None: + if catalog_table_input is not None: + path = catalog_table_input["StorageDescriptor"]["Location"] + else: + raise exceptions.InvalidArgumentValue( + "Glue table does not exist. Please pass the `path` argument to create it." + ) if pandas_kwargs.get("compression") not in ("gzip", "bz2", None): raise exceptions.InvalidArgumentCombination( "If database and table are given, you must use one of these compressions: gzip, bz2 or None." @@ -428,6 +457,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state df = _apply_dtype(df=df, dtype=dtype, catalog_table_input=catalog_table_input, mode=mode) + paths: List[str] = [] if dataset is False: pandas_kwargs["sep"] = sep pandas_kwargs["index"] = index @@ -441,7 +471,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state s3_additional_kwargs=s3_additional_kwargs, **pandas_kwargs, ) - paths = [path] + paths = [path] # type: ignore else: if database and table: quoting: Optional[int] = csv.QUOTE_NONE @@ -464,16 +494,58 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state pd_kwargs.pop("compression", None) df = df[columns] if columns else df + columns_types: Dict[str, str] = {} + partitions_types: Dict[str, str] = {} + if database and table: + columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned( + df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True + ) + if (catalog_table_input is None) and (table_type == "GOVERNED"): + catalog._create_csv_table( # pylint: disable=protected-access + database=database, + table=table, + path=path, + columns_types=columns_types, + table_type=table_type, + partitions_types=partitions_types, + bucketing_info=bucketing_info, + description=description, + parameters=parameters, + columns_comments=columns_comments, + boto3_session=session, + mode=mode, + catalog_versioning=catalog_versioning, + sep=sep, + projection_enabled=projection_enabled, + projection_types=projection_types, + projection_ranges=projection_ranges, + projection_values=projection_values, + projection_intervals=projection_intervals, + projection_digits=projection_digits, + catalog_table_input=catalog_table_input, + catalog_id=catalog_id, + compression=pandas_kwargs.get("compression"), + skip_header_line_count=None, + ) + catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access + database=database, table=table, boto3_session=session, catalog_id=catalog_id + ) paths, partitions_values = _to_dataset( func=_to_text, concurrent_partitioning=concurrent_partitioning, df=df, - path_root=path, + path_root=path, # type: ignore index=index, sep=sep, compression=compression, + catalog_id=catalog_id, + database=database, + table=table, + table_type=table_type, + transaction_id=transaction_id, use_threads=use_threads, partition_cols=partition_cols, + partitions_types=partitions_types, bucketing_info=bucketing_info, mode=mode, boto3_session=session, @@ -487,9 +559,6 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state ) if database and table: try: - 
columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned( - df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True - ) catalog._create_csv_table( # pylint: disable=protected-access database=database, table=table, @@ -516,7 +585,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state compression=pandas_kwargs.get("compression"), skip_header_line_count=None, ) - if partitions_values and (regular_partitions is True): + if partitions_values and (regular_partitions is True) and (table_type != "GOVERNED"): _logger.debug("partitions_values:\n%s", partitions_values) catalog.add_csv_partitions( database=database, From 866084b83a8a82456c8fe4ebcdcffd38055e6db6 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Sat, 20 Feb 2021 18:27:30 +0000 Subject: [PATCH 18/25] Minor - Disabling too-many-branches --- awswrangler/s3/_write_text.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/awswrangler/s3/_write_text.py b/awswrangler/s3/_write_text.py index e62e053af..516596b07 100644 --- a/awswrangler/s3/_write_text.py +++ b/awswrangler/s3/_write_text.py @@ -72,7 +72,7 @@ def _to_text( @apply_configs -def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements +def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements,too-many-branches df: pd.DataFrame, path: Optional[str] = None, sep: str = ",", From 65a5b09e6944451e0f6a3b8294fd3a53073c9e7a Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Mon, 22 Feb 2021 17:33:59 +0000 Subject: [PATCH 19/25] Major - Ready for release --- CONTRIBUTING.md | 11 + awswrangler/lakeformation/_read.py | 2 +- cloudformation/base.yaml | 113 ++++- tests/_utils.py | 33 +- tests/conftest.py | 5 + tests/test__routines.py | 148 +++--- tests/test_lakeformation.py | 126 ++++- ...029 - Lake Formation Governed Tables.ipynb | 448 ++++++++++++++++++ 8 files changed, 800 insertions(+), 86 deletions(-) create mode 100644 tutorials/029 - Lake Formation Governed Tables.ipynb diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 81361573d..9e1027319 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -230,6 +230,17 @@ or ``./cloudformation/delete-databases.sh`` +### Enabling Lake Formation: +If your feature is related to AWS Lake Formation, there are a number of additional steps required in order to complete testing: + +1. In the AWS console, enable Lake Formation by setting your IAM role as an Administrator and by unchecking the boxes in the ``Data Catalog Settings`` section + +2. In the ``./cloudformation/base.yaml`` template file, set ``EnableLakeFormation`` to ``True``. Then run the ``./deploy-base.sh`` once more to add an AWS Glue Database and an S3 bucket registered with Lake Formation + +3. Back in the console, in the ``Data Locations`` section, grant your IAM role access to the S3 Lake Formation bucket (``s3://aws-wrangler-base-lakeformation...``) + +4. 
Finally, in the ``Data Permissions`` section, grant your IAM role ``Super`` permissions on both the ``aws_data_wrangler`` and ``aws_data_wrangler_lakeformation`` databases + ## Recommended Visual Studio Code Recommended setting ```json diff --git a/awswrangler/lakeformation/_read.py b/awswrangler/lakeformation/_read.py index be414d753..007976bfd 100644 --- a/awswrangler/lakeformation/_read.py +++ b/awswrangler/lakeformation/_read.py @@ -105,7 +105,7 @@ def _resolve_sql_query( ) ) dfs = [df for df in dfs if not df.empty] - if not chunked: + if (not chunked) and dfs: return pd.concat(dfs, sort=False, copy=False, ignore_index=False) return dfs diff --git a/cloudformation/base.yaml b/cloudformation/base.yaml index 6e77560d4..76b69acff 100644 --- a/cloudformation/base.yaml +++ b/cloudformation/base.yaml @@ -1,6 +1,19 @@ AWSTemplateFormatVersion: 2010-09-09 Description: | AWS Data Wrangler Development Base Data Lake Infrastructure. VPC, Subnets, S3 Bucket, Glue Database, etc. +Parameters: + EnableLakeFormation: + Type: String + Description: set to True if Lake Formation is enabled in the account + Default: false + AllowedValues: + - true + - false +Conditions: + CreateLFResources: + Fn::Equals: + - Ref: EnableLakeFormation + - true Resources: VPC: Type: AWS::EC2::VPC @@ -161,6 +174,7 @@ Resources: - Key: Env Value: aws-data-wrangler Description: Aws Data Wrangler Test Key. + EnableKeyRotation: true KeyPolicy: Version: '2012-10-17' Id: aws-data-wrangler-key @@ -217,7 +231,99 @@ Resources: Ref: AWS::AccountId DatabaseInput: Name: aws_data_wrangler - Description: AWS Data Wrangler Test Arena - Glue Database + Description: AWS Data Wrangler Test Athena - Glue Database + LakeFormationBucket: + Type: AWS::S3::Bucket + Condition: CreateLFResources + Properties: + Tags: + - Key: Env + Value: aws-data-wrangler + PublicAccessBlockConfiguration: + BlockPublicAcls: true + BlockPublicPolicy: true + IgnorePublicAcls: true + RestrictPublicBuckets: true + LifecycleConfiguration: + Rules: + - Id: CleaningUp + Status: Enabled + ExpirationInDays: 1 + AbortIncompleteMultipartUpload: + DaysAfterInitiation: 1 + NoncurrentVersionExpirationInDays: 1 + LakeFormationTransactionRole: + Type: AWS::IAM::Role + Condition: CreateLFResources + Properties: + Tags: + - Key: Env + Value: aws-data-wrangler + AssumeRolePolicyDocument: + Version: 2012-10-17 + Statement: + - Effect: Allow + Principal: + Service: + - lakeformation.amazonaws.com + Action: + - sts:AssumeRole + Path: / + Policies: + - PolicyName: Root + PolicyDocument: + Version: 2012-10-17 + Statement: + - Effect: Allow + Action: + - s3:DeleteObject + - s3:GetObject + - s3:PutObject + Resource: + - Fn::Sub: arn:aws:s3:::${LakeFormationBucket}/* + - Effect: Allow + Action: + - s3:ListObject + Resource: + - Fn::Sub: arn:aws:s3:::${LakeFormationBucket} + - Effect: Allow + Action: + - execute-api:Invoke + Resource: arn:aws:execute-api:*:*:*/*/POST/reportStatus + - Effect: Allow + Action: + - lakeformation:AbortTransaction + - lakeformation:BeginTransaction + - lakeformation:CommitTransaction + - lakeformation:GetTableObjects + - lakeformation:UpdateTableObjects + Resource: '*' + - Effect: Allow + Action: + - glue:GetTable + - glue:GetPartitions + - glue:UpdateTable + Resource: '*' + LakeFormationBucketS3Registration: + Type: AWS::LakeFormation::Resource + Condition: CreateLFResources + Properties: + ResourceArn: + Fn::Sub: arn:aws:::s3:${LakeFormationBucket}/ + RoleArn: + Fn::GetAtt: + - LakeFormationTransactionRole + - Arn + UseServiceLinkedRole: false + 
LakeFormationGlueDatabase: + Type: AWS::Glue::Database + Condition: CreateLFResources + Properties: + CatalogId: + Ref: AWS::AccountId + DatabaseInput: + Name: aws_data_wrangler_lakeformation + Description: AWS Data Wrangler - Lake Formation Database LogGroup: Type: AWS::Logs::LogGroup Properties: @@ -274,6 +380,11 @@ Outputs: Value: Ref: GlueDatabase Description: Glue Database Name. + LakeFormationGlueDatabaseName: + Condition: CreateLFResources + Value: + Ref: LakeFormationGlueDatabase + Description: Lake Formation Glue Database Name. LogGroupName: Value: Ref: LogGroup diff --git a/tests/_utils.py b/tests/_utils.py index 85df69484..c931445c2 100644 --- a/tests/_utils.py +++ b/tests/_utils.py @@ -17,7 +17,7 @@ CFN_VALID_STATUS = ["CREATE_COMPLETE", "ROLLBACK_COMPLETE", "UPDATE_COMPLETE", "UPDATE_ROLLBACK_COMPLETE"] -def get_df(): +def get_df(governed=False): df = pd.DataFrame( { "iint8": [1, None, 2], @@ -45,10 +45,13 @@ def get_df(): df["float"] = df["float"].astype("float32") df["string"] = df["string"].astype("string") df["category"] = df["category"].astype("category") + + if governed: + df = df.drop(["iint8", "binary"], axis=1) # tinyint & binary currently not supported return df -def get_df_list(): +def get_df_list(governed=False): df = pd.DataFrame( { "iint8": [1, None, 2], @@ -79,10 +82,13 @@ def get_df_list(): df["float"] = df["float"].astype("float32") df["string"] = df["string"].astype("string") df["category"] = df["category"].astype("category") + + if governed: + df = df.drop(["iint8", "binary"], axis=1) # tinyint & binary currently not supported return df -def get_df_cast(): +def get_df_cast(governed=False): df = pd.DataFrame( { "iint8": [None, None, None], @@ -103,6 +109,8 @@ def get_df_cast(): "par1": ["a", "b", "b"], } ) + if governed: + df = df.drop(["iint8", "binary"], axis=1) # tinyint & binary currently not supported return df @@ -418,7 +426,7 @@ def get_query_long(): """ -def ensure_data_types(df, has_list=False): +def ensure_data_types(df, has_list=False, governed=False): if "iint8" in df.columns: assert str(df["iint8"].dtype).startswith("Int") assert str(df["iint16"].dtype).startswith("Int") @@ -430,7 +438,10 @@ def ensure_data_types(df, has_list=False): if "string_object" in df.columns: assert str(df["string_object"].dtype) == "string" assert str(df["string"].dtype) == "string" - assert str(df["date"].dtype) == "object" + if governed: + assert str(df["date"].dtype) == "datetime64[ns]" + else: + assert str(df["date"].dtype) == "object" assert str(df["timestamp"].dtype) == "datetime64[ns]" assert str(df["bool"].dtype) in ("boolean", "Int64", "object") if "binary" in df.columns: @@ -447,7 +458,10 @@ def ensure_data_types(df, has_list=False): if not row.empty: row = row.iloc[0] assert str(type(row["decimal"]).__name__) == "Decimal" - assert str(type(row["date"]).__name__) == "date" + if governed: + assert str(type(row["date"]).__name__) == "Timestamp" + else: + assert str(type(row["date"]).__name__) == "date" if "binary" in df.columns: assert str(type(row["binary"]).__name__) == "bytes" if has_list is True: @@ -468,7 +482,7 @@ def ensure_data_types_category(df): assert str(df["par1"].dtype) == "category" -def ensure_data_types_csv(df): +def ensure_data_types_csv(df, governed=False): if "__index_level_0__" in df: assert str(df["__index_level_0__"].dtype).startswith("Int") assert str(df["id"].dtype).startswith("Int") @@ -480,7 +494,7 @@ def ensure_data_types_csv(df): assert str(df["float"].dtype).startswith("float") if "int" in df: assert
str(df["int"].dtype).startswith("Int") - assert str(df["date"].dtype) == "object" + if governed: + assert str(df["date"].dtype).startswith("datetime") + else: + assert str(df["date"].dtype) == "object" assert str(df["timestamp"].dtype).startswith("datetime") if "bool" in df: assert str(df["bool"].dtype) == "boolean" diff --git a/tests/conftest.py b/tests/conftest.py index 011fccfca..7bdb19b64 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,11 @@ def glue_database(cloudformation_outputs): return cloudformation_outputs["GlueDatabaseName"] +@pytest.fixture(scope="session") +def lakeformation_glue_database(cloudformation_outputs): + return cloudformation_outputs["LakeFormationGlueDatabaseName"] + + @pytest.fixture(scope="session") def kms_key(cloudformation_outputs): return cloudformation_outputs["KmsKeyArn"] diff --git a/tests/test__routines.py b/tests/test__routines.py index 57c9f2526..70b8a5428 100644 --- a/tests/test__routines.py +++ b/tests/test__routines.py @@ -11,7 +11,12 @@ @pytest.mark.parametrize("use_threads", [True, False]) @pytest.mark.parametrize("concurrent_partitioning", [True, False]) @pytest.mark.parametrize("table_type", ["EXTERNAL_TABLE", "GOVERNED"]) -def test_routine_0(glue_database, glue_table, table_type, path, use_threads, concurrent_partitioning): +def test_routine_0( + lakeformation_glue_database, glue_database, glue_table, table_type, path, use_threads, concurrent_partitioning +): + + table = f"__{glue_table}" + database = lakeformation_glue_database if table_type == "GOVERNED" else glue_database # Round 1 - Warm up df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") @@ -20,8 +25,8 @@ def test_routine_0(glue_database, glue_table, table_type, path, use_threads, con path=path, dataset=True, mode="overwrite", - database=glue_database, - table=glue_table, + database=database, + table=table, table_type=table_type, description="c0", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, @@ -29,19 +34,19 @@ def test_routine_0(glue_database, glue_table, table_type, path, use_threads, con use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 if table_type == "GOVERNED": - df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) else: - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == str(len(df2.columns)) assert parameters["num_rows"] == str(len(df2.index)) - assert wr.catalog.get_table_description(glue_database, glue_table) == "c0" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c0" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c0"] == "0" @@ -51,8 +56,8 @@ def test_routine_0(glue_database, glue_table, table_type, path, use_threads, con df=df, dataset=True, mode="overwrite", - 
database=glue_database, - table=glue_table, + database=database, + table=table, table_type=table_type, description="c1", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, @@ -60,19 +65,19 @@ def test_routine_0(glue_database, glue_table, table_type, path, use_threads, con use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 if table_type == "GOVERNED": - df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) else: - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert df.shape == df2.shape assert df.c1.sum() == df2.c1.sum() - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == str(len(df2.columns)) assert parameters["num_rows"] == str(len(df2.index)) - assert wr.catalog.get_table_description(glue_database, glue_table) == "c1" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c1" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c1"] == "1" @@ -83,8 +88,8 @@ def test_routine_0(glue_database, glue_table, table_type, path, use_threads, con path=path, dataset=True, mode="append", - database=glue_database, - table=glue_table, + database=database, + table=table, table_type=table_type, description="c1", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index) * 2)}, @@ -92,20 +97,20 @@ def test_routine_0(glue_database, glue_table, table_type, path, use_threads, con use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 if table_type == "GOVERNED": - df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) else: - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert len(df.columns) == len(df2.columns) assert len(df.index) * 2 == len(df2.index) assert df.c1.sum() + 1 == df2.c1.sum() - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == str(len(df2.columns)) assert parameters["num_rows"] == str(len(df2.index)) - assert wr.catalog.get_table_description(glue_database, glue_table) == "c1" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c1" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c1"] == "1" @@ -115,8 +120,8 @@ def test_routine_0(glue_database, glue_table, table_type, path, use_threads, con df=df, dataset=True, 
mode="append", - database=glue_database, - table=glue_table, + database=database, + table=table, table_type=table_type, description="c1+c2", parameters={"num_cols": "2", "num_rows": "9"}, @@ -124,20 +129,20 @@ def test_routine_0(glue_database, glue_table, table_type, path, use_threads, con use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 if table_type == "GOVERNED": - df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) else: - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert len(df2.columns) == 2 assert len(df2.index) == 9 assert df2.c1.sum() == 3 - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == "2" assert parameters["num_rows"] == "9" - assert wr.catalog.get_table_description(glue_database, glue_table) == "c1+c2" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c1+c2" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c1"] == "1" assert comments["c2"] == "2" @@ -149,8 +154,8 @@ def test_routine_0(glue_database, glue_table, table_type, path, use_threads, con path=path, dataset=True, mode="append", - database=glue_database, - table=glue_table, + database=database, + table=table, table_type=table_type, description="c1+c2+c3", parameters={"num_cols": "3", "num_rows": "10"}, @@ -158,30 +163,37 @@ def test_routine_0(glue_database, glue_table, table_type, path, use_threads, con use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 if table_type == "GOVERNED": - df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) else: - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert len(df2.columns) == 3 assert len(df2.index) == 10 assert df2.c1.sum() == 4 - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == "3" assert parameters["num_rows"] == "10" - assert wr.catalog.get_table_description(glue_database, glue_table) == "c1+c2+c3" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c1+c2+c3" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c1"] == "1!" assert comments["c2"] == "2!" 
assert comments["c3"] == "3" + wr.catalog.delete_table_if_exists(database=database, table=table) + @pytest.mark.parametrize("use_threads", [True, False]) @pytest.mark.parametrize("concurrent_partitioning", [True, False]) @pytest.mark.parametrize("table_type", ["EXTERNAL_TABLE", "GOVERNED"]) -def test_routine_1(glue_database, glue_table, table_type, path, use_threads, concurrent_partitioning): +def test_routine_1( + lakeformation_glue_database, glue_database, glue_table, table_type, path, use_threads, concurrent_partitioning +): + + table = f"__{glue_table}" + database = lakeformation_glue_database if table_type == "GOVERNED" else glue_database # Round 1 - Overwrite Partitioned df = pd.DataFrame({"c0": ["foo", None], "c1": [0, 1]}) @@ -190,8 +202,8 @@ def test_routine_1(glue_database, glue_table, table_type, path, use_threads, con path=path, dataset=True, mode="overwrite", - database=glue_database, - table=glue_table, + database=database, + table=table, table_type=table_type, partition_cols=["c1"], description="c0+c1", @@ -200,19 +212,19 @@ def test_routine_1(glue_database, glue_table, table_type, path, use_threads, con use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 if table_type == "GOVERNED": - df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) else: - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert df.shape == df2.shape assert df.c1.sum() == df2.c1.sum() - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == "2" assert parameters["num_rows"] == "2" - assert wr.catalog.get_table_description(glue_database, glue_table) == "c0+c1" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c0+c1" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c0"] == "zero" assert comments["c1"] == "one" @@ -223,8 +235,8 @@ def test_routine_1(glue_database, glue_table, table_type, path, use_threads, con df=df, dataset=True, mode="overwrite_partitions", - database=glue_database, - table=glue_table, + database=database, + table=table, table_type=table_type, partition_cols=["c1"], description="c0+c1", @@ -233,20 +245,20 @@ def test_routine_1(glue_database, glue_table, table_type, path, use_threads, con concurrent_partitioning=concurrent_partitioning, use_threads=use_threads, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 if table_type == "GOVERNED": - df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) else: - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert len(df2.columns) == 2 assert len(df2.index) == 3 assert 
df2.c1.sum() == 3 - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == "2" assert parameters["num_rows"] == "3" - assert wr.catalog.get_table_description(glue_database, glue_table) == "c0+c1" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c0+c1" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c0"] == "zero" assert comments["c1"] == "one" @@ -258,8 +270,8 @@ def test_routine_1(glue_database, glue_table, table_type, path, use_threads, con path=path, dataset=True, mode="overwrite_partitions", - database=glue_database, - table=glue_table, + database=database, + table=table, table_type=table_type, partition_cols=["c1"], description="c0+c1+c2", @@ -268,25 +280,27 @@ def test_routine_1(glue_database, glue_table, table_type, path, use_threads, con use_threads=use_threads, concurrent_partitioning=concurrent_partitioning, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 1 + assert wr.catalog.get_table_number_of_versions(table=table, database=database) == 1 if table_type == "GOVERNED": - df2 = wr.lakeformation.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.lakeformation.read_sql_table(table, database, use_threads=use_threads) else: - df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads) + df2 = wr.athena.read_sql_table(table, database, use_threads=use_threads) assert len(df2.columns) == 3 assert len(df2.index) == 4 assert df2.c1.sum() == 6 - parameters = wr.catalog.get_table_parameters(glue_database, glue_table) + parameters = wr.catalog.get_table_parameters(database, table) assert len(parameters) >= 5 assert parameters["num_cols"] == "3" assert parameters["num_rows"] == "4" - assert wr.catalog.get_table_description(glue_database, glue_table) == "c0+c1+c2" - comments = wr.catalog.get_columns_comments(glue_database, glue_table) + assert wr.catalog.get_table_description(database, table) == "c0+c1+c2" + comments = wr.catalog.get_columns_comments(database, table) assert len(comments) == len(df.columns) assert comments["c0"] == "zero" assert comments["c1"] == "one" assert comments["c2"] == "two" + wr.catalog.delete_table_if_exists(database=database, table=table) + def test_routine_2(glue_database, glue_table, path): @@ -481,3 +495,5 @@ def test_routine_2(glue_database, glue_table, path): assert comments["c0"] == "zero" assert comments["c1"] == "one" assert comments["c2"] == "two" + + wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) diff --git a/tests/test_lakeformation.py b/tests/test_lakeformation.py index a36759de8..242cb3a0f 100644 --- a/tests/test_lakeformation.py +++ b/tests/test_lakeformation.py @@ -1,20 +1,24 @@ +import calendar import logging +import time -import pytest +import pandas as pd import awswrangler as wr -from ._utils import get_df +from ._utils import ensure_data_types, ensure_data_types_csv, get_df, get_df_csv logging.getLogger("awswrangler").setLevel(logging.DEBUG) -@pytest.mark.parametrize("use_threads", [True, False]) -def test_lakeformation(path, glue_database, glue_table, use_threads): +def test_lakeformation(path, path2, lakeformation_glue_database, glue_table, glue_table2, use_threads=False): table = f"__{glue_table}" - 
wr.catalog.delete_table_if_exists(database=glue_database, table=table) + table2 = f"__{glue_table2}" + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table) + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table2) + wr.s3.to_parquet( - df=get_df().drop(["iint8", "binary"], axis=1), # tinyint & binary currently not supported + df=get_df(governed=True), path=path, index=False, boto3_session=None, @@ -23,22 +27,124 @@ def test_lakeformation(path, glue_database, glue_table, use_threads): partition_cols=["par0", "par1"], mode="overwrite", table=table, - database=glue_database, + table_type="GOVERNED", + database=lakeformation_glue_database, ) df = wr.lakeformation.read_sql_table( table=table, - database=glue_database, + database=lakeformation_glue_database, use_threads=use_threads, ) assert len(df.index) == 3 assert len(df.columns) == 14 assert df["iint32"].sum() == 3 + ensure_data_types(df=df, governed=True) + # Filter query df2 = wr.lakeformation.read_sql_query( sql=f"SELECT * FROM {table} WHERE iint16 = :iint16;", - database=glue_database, + database=lakeformation_glue_database, params={"iint16": 1}, ) assert len(df2.index) == 1 - wr.catalog.delete_table_if_exists(database=glue_database, table=table) + + wr.s3.to_csv( + df=get_df_csv(), + path=path2, + index=False, + boto3_session=None, + s3_additional_kwargs=None, + dataset=True, + partition_cols=["par0", "par1"], + mode="append", + table=table2, + table_type="GOVERNED", + database=lakeformation_glue_database, + ) + # Read within a transaction + transaction_id = wr.lakeformation.begin_transaction(read_only=True) + df3 = wr.lakeformation.read_sql_table( + table=table2, + database=lakeformation_glue_database, + transaction_id=transaction_id, + use_threads=use_threads, + ) + assert df3["int"].sum() == 3 + ensure_data_types_csv(df3, governed=True) + + # Read within a query as of time + query_as_of_time = calendar.timegm(time.gmtime()) + df4 = wr.lakeformation.read_sql_table( + table=table2, + database=lakeformation_glue_database, + query_as_of_time=query_as_of_time, + use_threads=use_threads, + ) + assert len(df4.index) == 3 + + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table) + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table2) + + +def test_lakeformation_multi_transaction( + path, path2, lakeformation_glue_database, glue_table, glue_table2, use_threads=True +): + table = f"__{glue_table}" + table2 = f"__{glue_table2}" + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table) + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table2) + + df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") + transaction_id = wr.lakeformation.begin_transaction(read_only=False) + wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + mode="append", + database=lakeformation_glue_database, + table=table, + table_type="GOVERNED", + transaction_id=transaction_id, + description="c0", + parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, + columns_comments={"c0": "0"}, + use_threads=use_threads, + ) + + df2 = pd.DataFrame({"c1": [None, 1, None]}, dtype="Int16") + wr.s3.to_parquet( + df=df2, + path=path2, + dataset=True, + mode="append", + database=lakeformation_glue_database, + table=table2, + table_type="GOVERNED", + transaction_id=transaction_id, + description="c1", + parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, + 
columns_comments={"c1": "1"}, + use_threads=use_threads, + ) + wr.lakeformation.commit_transaction(transaction_id=transaction_id) + + df3 = wr.lakeformation.read_sql_table( + table=table, + database=lakeformation_glue_database, + use_threads=use_threads, + ) + df4 = wr.lakeformation.read_sql_table( + table=table2, + database=lakeformation_glue_database, + use_threads=use_threads, + ) + + assert df.shape == df3.shape + assert df.c0.sum() == df3.c0.sum() + + assert df2.shape == df4.shape + assert df2.c1.sum() == df4.c1.sum() + + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table) + wr.catalog.delete_table_if_exists(database=lakeformation_glue_database, table=table2) diff --git a/tutorials/029 - Lake Formation Governed Tables.ipynb b/tutorials/029 - Lake Formation Governed Tables.ipynb new file mode 100644 index 000000000..6aaec37aa --- /dev/null +++ b/tutorials/029 - Lake Formation Governed Tables.ipynb @@ -0,0 +1,448 @@ +{ + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + }, + "orig_nbformat": 2, + "kernelspec": { + "name": "python3", + "display_name": "Python 3.9.1 64-bit ('.venv': venv)", + "metadata": { + "interpreter": { + "hash": "2878c7ae46413c5ab07cafef85a7415922732432fa2f847b9105997e244ed975" + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 2, + "cells": [ + { + "source": [ + "[![AWS Data Wrangler](_static/logo.png \"AWS Data Wrangler\")](https://github.com/awslabs/aws-data-wrangler)" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "# AWS Lake Formation - Glue Governed tables" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "### This tutorial assumes that the IAM user/role has the required Lake Formation permissions to create and read AWS Glue governed tables in Lake Formation" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "## Table of Contents\n", + "* [1. Read Governed table](#1.-Read-Governed-table)\n", + " * [1.1 Read PartiQL query](#1.1-Read-PartiQL-query)\n", + " * [1.1.1 Read within transaction](#1.1.1-Read-within-transaction)\n", + " * [1.1.2 Read within query as of time](#1.1.2-Read-within-query-as-of-time)\n", + " * [1.2 Read full table](#1.2-Read-full-table)\n", + "* [2. Write Governed table](#2.-Write-Governed-table)\n", + " * [2.1 Create new Governed table](#2.1-Create-new-Governed-table)\n", + " * [2.1.1 CSV table](#2.1.1-CSV-table)\n", + " * [2.1.2 Parquet table](#2.1.2-Parquet-table)\n", + " * [2.2 Overwrite operations](#2.2-Overwrite-operations)\n", + " * [2.2.1 Overwrite](#2.2.1-Overwrite)\n", + " * [2.2.2 Append](#2.2.2-Append)\n", + " * [2.2.3 Create partitioned Governed table](#2.2.3-Create-partitioned-Governed-table)\n", + " * [2.2.4 Overwrite partitions](#2.2.4-Overwrite-partitions)\n", + "* [3. Multiple read/write operations within a transaction](#2.-Multiple-read/write-operations-within-a-transaction)" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "# 1. 
Read Governed table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "## 1.1 Read PartiQL query" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import awswrangler as wr\n", + "\n", + "database = \"gov_db\" # Assumes a Glue database registered with Lake Formation exists in the account\n", + "table = \"gov_table\" # Assumes a Governed table of the same name exists in the account\n", + "catalog_id = \"111111111111\" # AWS Account Id\n", + "\n", + "# Note 1: If a transaction_id is not specified, a new transaction is started\n", + "df = wr.lakeformation.read_sql_query(\n", + " sql=f\"SELECT * FROM {table};\",\n", + " database=database,\n", + " catalog_id=\"111111111111\"\n", + ")" + ] + }, + { + "source": [ + "### 1.1.1 Read within transaction" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "transaction_id = wr.lakeformation.begin_transaction(read_only=True)\n", + "df = wr.lakeformation.read_sql_query(\n", + " sql=f\"SELECT * FROM {table};\",\n", + " database=database,\n", + " transaction_id=transaction_id\n", + ")" + ] + }, + { + "source": [ + "### 1.1.2 Read within query as of time" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import calendar\n", + "import time\n", + "\n", + "query_as_of_time = query_as_of_time = calendar.timegm(time.gmtime())\n", + "df = wr.lakeformation.read_sql_query(\n", + " sql=f\"SELECT * FROM {table} WHERE id = :id;\",\n", + " database=database,\n", + " query_as_of_time=query_as_of_time,\n", + " params={\"id\": 1}\n", + ")" + ] + }, + { + "source": [ + "## 1.2 Read full table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = wr.lakeformation.read_sql_table(\n", + " table=table,\n", + " database=database,\n", + " use_threads=True\n", + ")" + ] + }, + { + "source": [ + "# 2. Write Governed table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "## 2.1 Create a new Governed table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "## Enter your bucket name:" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "\n", + "bucket = getpass.getpass()" + ] + }, + { + "source": [ + "### If a governed table does not exist, you can specify an S3 `path` and it will be created. 
Make sure your IAM user/role has enough permissions in the Lake Formation database" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "### 2.1.1 CSV table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "table = \"gov_table_csv\"\n", + "\n", + "# Note 1: If a transaction_id is not specified, a new transaction is started\n", + "df=pd.DataFrame({\n", + " \"col\": [1, 2, 3],\n", + " \"col2\": [\"A\", \"A\", \"B\"],\n", + " \"col3\": [None, \"test\", None]\n", + "})\n", + "wr.s3.to_csv(\n", + " df=df,\n", + " path=f\"s3://{bucket}/{table}/\", # S3 path\n", + " dataset=True,\n", + " database=database,\n", + " table=table,\n", + " table_type=\"GOVERNED\"\n", + ")" + ] + }, + { + "source": [ + "### 2.1.2 Parquet table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "table = \"gov_table_parquet\"\n", + "\n", + "df = pd.DataFrame({\"c0\": [0, None]}, dtype=\"Int64\")\n", + "wr.s3.to_parquet(\n", + " df=df,\n", + " path=f\"s3://{bucket}/{table}/\",\n", + " dataset=True,\n", + " database=database,\n", + " table=table,\n", + " table_type=\"GOVERNED\",\n", + " description=\"c0\",\n", + " parameters={\"num_cols\": str(len(df.columns)), \"num_rows\": str(len(df.index))},\n", + " columns_comments={\"c0\": \"0\"},\n", + " use_threads=True\n", + ")" + ] + }, + { + "source": [ + "## 2.2 Overwrite operations" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "source": [ + "### 2.2.1 Overwrite" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.DataFrame({\"c1\": [None, 1, None]}, dtype=\"Int16\")\n", + "wr.s3.to_parquet(\n", + " df=df,\n", + " dataset=True,\n", + " mode=\"overwrite\",\n", + " database=database,\n", + " table=table,\n", + " table_type=\"GOVERNED\",\n", + " description=\"c1\",\n", + " parameters={\"num_cols\": str(len(df.columns)), \"num_rows\": str(len(df.index))},\n", + " columns_comments={\"c1\": \"1\"},\n", + " use_threads=True\n", + ")" + ] + }, + { + "source": [ + "### 2.2.2 Append" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.DataFrame({\"c1\": [None, 2, None]}, dtype=\"Int8\")\n", + "wr.s3.to_parquet(\n", + " df=df,\n", + " dataset=True,\n", + " mode=\"append\",\n", + " database=database,\n", + " table=table,\n", + " table_type=\"GOVERNED\",\n", + " description=\"c1\",\n", + " parameters={\"num_cols\": str(len(df.columns)), \"num_rows\": str(len(df.index) * 2)},\n", + " columns_comments={\"c1\": \"1\"},\n", + " use_threads=True\n", + ")" + ] + }, + { + "source": [ + "### 2.2.3 Create partitioned Governed table" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "table = \"gov_table_parquet_partitioned\"\n", + "\n", + "df = pd.DataFrame({\"c0\": [\"foo\", None], \"c1\": [0, 1]})\n", + "wr.s3.to_parquet(\n", + " df=df,\n", + " path=f\"s3://{bucket}/{table}/\",\n", + " dataset=True,\n", + " database=database,\n", + " table=table,\n", + " table_type=\"GOVERNED\",\n", + " partition_cols=[\"c1\"],\n", + " description=\"c0+c1\",\n", + " parameters={\"num_cols\": 
\"2\", \"num_rows\": \"2\"},\n", + " columns_comments={\"c0\": \"zero\", \"c1\": \"one\"},\n", + " use_threads=True\n", + ")" + ] + }, + { + "source": [ + "### 2.2.4 Overwrite partitions" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.DataFrame({\"c0\": [None, None], \"c1\": [0, 2]})\n", + "wr.s3.to_parquet(\n", + " df=df,\n", + " dataset=True,\n", + " mode=\"overwrite_partitions\",\n", + " database=database,\n", + " table=table,\n", + " table_type=\"GOVERNED\",\n", + " partition_cols=[\"c1\"],\n", + " description=\"c0+c1\",\n", + " parameters={\"num_cols\": \"2\", \"num_rows\": \"3\"},\n", + " columns_comments={\"c0\": \"zero\", \"c1\": \"one\"}\n", + ")" + ] + }, + { + "source": [ + "# 3. Multiple read/write operations within a transaction" + ], + "cell_type": "markdown", + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "read_table = \"gov_table_parquet\"\n", + "write_table = \"gov_table_multi_parquet\"\n", + "\n", + "transaction_id = wr.lakeformation.begin_transaction(read_only=False)\n", + "\n", + "df = pd.DataFrame({\"c0\": [0, None]}, dtype=\"Int64\")\n", + "wr.s3.to_parquet(\n", + " df=df,\n", + " path=f\"s3://{bucket}/{table}/{write_table}_1\",\n", + " dataset=True,\n", + " database=database,\n", + " table=f\"{write_table}_1\",\n", + " table_type=\"GOVERNED\",\n", + " transaction_id=transaction_id,\n", + ")\n", + "\n", + "df2 = wr.lakeformation.read_sql_table(\n", + " table=read_table,\n", + " database=database,\n", + " transaction_id=transaction_id,\n", + " use_threads=True\n", + ")\n", + "\n", + "df3 = pd.DataFrame({\"c1\": [None, 1, None]}, dtype=\"Int16\")\n", + "wr.s3.to_parquet(\n", + " df=df2,\n", + " path=f\"s3://{bucket}/{table}/{write_table}_2\",\n", + " dataset=True,\n", + " mode=\"append\",\n", + " database=database,\n", + " table=f\"{write_table}_2\",\n", + " table_type=\"GOVERNED\",\n", + " transaction_id=transaction_id,\n", + ")\n", + "\n", + "wr.lakeformation.commit_transaction(transaction_id=transaction_id)" + ] + } + ] +} \ No newline at end of file From 78194ed1872ddb8f7ae7c58e775a756e4c8310fd Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Tue, 23 Feb 2021 08:55:38 +0000 Subject: [PATCH 20/25] Minor - Proofreading --- tutorials/029 - Lake Formation Governed Tables.ipynb | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tutorials/029 - Lake Formation Governed Tables.ipynb b/tutorials/029 - Lake Formation Governed Tables.ipynb index 6aaec37aa..75d0ab6f3 100644 --- a/tutorials/029 - Lake Formation Governed Tables.ipynb +++ b/tutorials/029 - Lake Formation Governed Tables.ipynb @@ -42,7 +42,7 @@ }, { "source": [ - "### This tutorial assumes that the IAM user/role has the required Lake Formation permissions to create and read AWS Glue governed tables in Lake Formation" + "### This tutorial assumes that your IAM user/role has the required Lake Formation permissions to create and read AWS Glue Governed tables" ], "cell_type": "markdown", "metadata": {} @@ -92,7 +92,7 @@ "import awswrangler as wr\n", "\n", "database = \"gov_db\" # Assumes a Glue database registered with Lake Formation exists in the account\n", - "table = \"gov_table\" # Assumes a Governed table of the same name exists in the account\n", + "table = \"gov_table\" # Assumes a Governed table exists in the account\n", "catalog_id = \"111111111111\" # AWS Account Id\n", "\n", "# Note 1: If 
a transaction_id is not specified, a new transaction is started\n", @@ -203,7 +203,7 @@ }, { "source": [ - "### If a governed table does not exist, you can specify an S3 `path` and it will be created. Make sure your IAM user/role has enough permissions in the Lake Formation database" + "### If a governed table does not exist, it can be created by passing an S3 `path` argument. Make sure your IAM user/role has enough permissions in the Lake Formation database" ], "cell_type": "markdown", "metadata": {} @@ -225,12 +225,12 @@ "\n", "table = \"gov_table_csv\"\n", "\n", - "# Note 1: If a transaction_id is not specified, a new transaction is started\n", "df=pd.DataFrame({\n", " \"col\": [1, 2, 3],\n", " \"col2\": [\"A\", \"A\", \"B\"],\n", " \"col3\": [None, \"test\", None]\n", "})\n", + "# Note 1: If a transaction_id is not specified, a new transaction is started\n", "wr.s3.to_csv(\n", " df=df,\n", " path=f\"s3://{bucket}/{table}/\", # S3 path\n", From e7cc97f9febdf34f8791784d82957f8b7d549b01 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Tue, 23 Feb 2021 16:43:43 +0000 Subject: [PATCH 21/25] Minor - Removing needless use_threads argument --- .../029 - Lake Formation Governed Tables.ipynb | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/tutorials/029 - Lake Formation Governed Tables.ipynb b/tutorials/029 - Lake Formation Governed Tables.ipynb index 75d0ab6f3..6982a10c5 100644 --- a/tutorials/029 - Lake Formation Governed Tables.ipynb +++ b/tutorials/029 - Lake Formation Governed Tables.ipynb @@ -99,7 +99,7 @@ "df = wr.lakeformation.read_sql_query(\n", " sql=f\"SELECT * FROM {table};\",\n", " database=database,\n", - " catalog_id=\"111111111111\"\n", + " catalog_id=catalog_id\n", ")" ] }, @@ -164,8 +164,7 @@ "source": [ "df = wr.lakeformation.read_sql_table(\n", " table=table,\n", - " database=database,\n", - " use_threads=True\n", + " database=database\n", ")" ] }, @@ -266,8 +265,7 @@ " table_type=\"GOVERNED\",\n", " description=\"c0\",\n", " parameters={\"num_cols\": str(len(df.columns)), \"num_rows\": str(len(df.index))},\n", - " columns_comments={\"c0\": \"0\"},\n", - " use_threads=True\n", + " columns_comments={\"c0\": \"0\"}\n", ")" ] }, @@ -301,8 +299,7 @@ " table_type=\"GOVERNED\",\n", " description=\"c1\",\n", " parameters={\"num_cols\": str(len(df.columns)), \"num_rows\": str(len(df.index))},\n", - " columns_comments={\"c1\": \"1\"},\n", - " use_threads=True\n", + " columns_comments={\"c1\": \"1\"}\n", ")" ] }, @@ -329,8 +326,7 @@ " table_type=\"GOVERNED\",\n", " description=\"c1\",\n", " parameters={\"num_cols\": str(len(df.columns)), \"num_rows\": str(len(df.index) * 2)},\n", - " columns_comments={\"c1\": \"1\"},\n", - " use_threads=True\n", + " columns_comments={\"c1\": \"1\"}\n", ")" ] }, @@ -360,8 +356,7 @@ " partition_cols=[\"c1\"],\n", " description=\"c0+c1\",\n", " parameters={\"num_cols\": \"2\", \"num_rows\": \"2\"},\n", - " columns_comments={\"c0\": \"zero\", \"c1\": \"one\"},\n", - " use_threads=True\n", + " columns_comments={\"c0\": \"zero\", \"c1\": \"one\"}\n", ")" ] }, From 9bbd007772d64460e1f2912caf4f5876fc0f8436 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Thu, 25 Feb 2021 10:53:58 +0000 Subject: [PATCH 22/25] Minor - Removing the need to specify table_type when table is already created --- awswrangler/s3/_write.py | 6 ------ awswrangler/s3/_write_parquet.py | 10 ++++++---- awswrangler/s3/_write_text.py | 8 +++++--- tests/test__routines.py | 6 ------ .../029 - Lake Formation Governed Tables.ipynb | 14 ++++++-------- 5 files 
changed, 17 insertions(+), 27 deletions(-) diff --git a/awswrangler/s3/_write.py b/awswrangler/s3/_write.py index 0ed48535e..666035bb6 100644 --- a/awswrangler/s3/_write.py +++ b/awswrangler/s3/_write.py @@ -48,8 +48,6 @@ def _validate_args( database: Optional[str], dataset: bool, path: Optional[str], - table_type: Optional[str], - transaction_id: Optional[str], partition_cols: Optional[List[str]], bucketing_info: Optional[Tuple[List[str], int]], mode: Optional[str], @@ -83,10 +81,6 @@ def _validate_args( "Arguments database and table must be passed together. If you want to store your dataset metadata in " "the Glue Catalog, please ensure you are passing both." ) - elif (table_type != "GOVERNED") and (transaction_id is not None): - raise exceptions.InvalidArgumentCombination( - "When passing a `transaction_id` as an argument, `table_type` must be set to 'GOVERNED'" - ) elif bucketing_info and bucketing_info[1] <= 0: raise exceptions.InvalidArgumentValue( "Please pass a value greater than 1 for the number of buckets for bucketing." diff --git a/awswrangler/s3/_write_parquet.py b/awswrangler/s3/_write_parquet.py index a5fecd1e3..5ea2ff9c0 100644 --- a/awswrangler/s3/_write_parquet.py +++ b/awswrangler/s3/_write_parquet.py @@ -196,7 +196,7 @@ def _to_parquet( @apply_configs -def to_parquet( # pylint: disable=too-many-arguments,too-many-locals +def to_parquet( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements df: pd.DataFrame, path: Optional[str] = None, index: bool = False, @@ -507,8 +507,6 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals database=database, dataset=dataset, path=path, - table_type=table_type, - transaction_id=transaction_id, partition_cols=partition_cols, bucketing_info=bucketing_info, mode=mode, @@ -527,6 +525,8 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals dtype = dtype if dtype else {} partitions_values: Dict[str, List[str]] = {} mode = "append" if mode is None else mode + if transaction_id: + table_type = "GOVERNED" cpus: int = _utils.ensure_cpu_count(use_threads=use_threads) session: boto3.Session = _utils.ensure_session(session=boto3_session) @@ -540,8 +540,10 @@ def to_parquet( # pylint: disable=too-many-arguments,too-many-locals catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access database=database, table=table, boto3_session=session, catalog_id=catalog_id ) + if catalog_table_input: + table_type = catalog_table_input["TableType"] if path is None: - if catalog_table_input is not None: + if catalog_table_input: path = catalog_table_input["StorageDescriptor"]["Location"] else: raise exceptions.InvalidArgumentValue( diff --git a/awswrangler/s3/_write_text.py b/awswrangler/s3/_write_text.py index 516596b07..75d9324e4 100644 --- a/awswrangler/s3/_write_text.py +++ b/awswrangler/s3/_write_text.py @@ -416,8 +416,6 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state database=database, dataset=dataset, path=path, - table_type=table_type, - transaction_id=transaction_id, partition_cols=partition_cols, bucketing_info=bucketing_info, mode=mode, @@ -431,6 +429,8 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state dtype = dtype if dtype else {} partitions_values: Dict[str, List[str]] = {} mode = "append" if mode is None else mode + if transaction_id: + table_type = "GOVERNED" session: boto3.Session = _utils.ensure_session(session=boto3_session) # Sanitize table to respect Athena's standards @@ -443,8 
+443,10 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state catalog_table_input = catalog._get_table_input( # pylint: disable=protected-access database=database, table=table, boto3_session=session, catalog_id=catalog_id ) + if catalog_table_input: + table_type = catalog_table_input["TableType"] if path is None: - if catalog_table_input is not None: + if catalog_table_input: path = catalog_table_input["StorageDescriptor"]["Location"] else: raise exceptions.InvalidArgumentValue( diff --git a/tests/test__routines.py b/tests/test__routines.py index 70b8a5428..96f430059 100644 --- a/tests/test__routines.py +++ b/tests/test__routines.py @@ -58,7 +58,6 @@ def test_routine_0( mode="overwrite", database=database, table=table, - table_type=table_type, description="c1", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, columns_comments={"c1": "1"}, @@ -90,7 +89,6 @@ def test_routine_0( mode="append", database=database, table=table, - table_type=table_type, description="c1", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index) * 2)}, columns_comments={"c1": "1"}, @@ -122,7 +120,6 @@ def test_routine_0( mode="append", database=database, table=table, - table_type=table_type, description="c1+c2", parameters={"num_cols": "2", "num_rows": "9"}, columns_comments={"c1": "1", "c2": "2"}, @@ -156,7 +153,6 @@ def test_routine_0( mode="append", database=database, table=table, - table_type=table_type, description="c1+c2+c3", parameters={"num_cols": "3", "num_rows": "10"}, columns_comments={"c1": "1!", "c2": "2!", "c3": "3"}, @@ -237,7 +233,6 @@ def test_routine_1( mode="overwrite_partitions", database=database, table=table, - table_type=table_type, partition_cols=["c1"], description="c0+c1", parameters={"num_cols": "2", "num_rows": "3"}, @@ -272,7 +267,6 @@ def test_routine_1( mode="overwrite_partitions", database=database, table=table, - table_type=table_type, partition_cols=["c1"], description="c0+c1+c2", parameters={"num_cols": "3", "num_rows": "4"}, diff --git a/tutorials/029 - Lake Formation Governed Tables.ipynb b/tutorials/029 - Lake Formation Governed Tables.ipynb index 6982a10c5..56133b281 100644 --- a/tutorials/029 - Lake Formation Governed Tables.ipynb +++ b/tutorials/029 - Lake Formation Governed Tables.ipynb @@ -230,9 +230,10 @@ " \"col3\": [None, \"test\", None]\n", "})\n", "# Note 1: If a transaction_id is not specified, a new transaction is started\n", + "# Note 2: When creating a new Governed table, `table_type=\"GOVERNED\"` must be specified. 
Otherwise the default is to create an EXTERNAL_TABLE\n", "wr.s3.to_csv(\n", " df=df,\n", - " path=f\"s3://{bucket}/{table}/\", # S3 path\n", + " path=f\"s3://{bucket}/{database}/{table}/\", # S3 path\n", " dataset=True,\n", " database=database,\n", " table=table,\n", @@ -258,7 +259,7 @@ "df = pd.DataFrame({\"c0\": [0, None]}, dtype=\"Int64\")\n", "wr.s3.to_parquet(\n", " df=df,\n", - " path=f\"s3://{bucket}/{table}/\",\n", + " path=f\"s3://{bucket}/{database}/{table}/\",\n", " dataset=True,\n", " database=database,\n", " table=table,\n", @@ -296,7 +297,6 @@ " mode=\"overwrite\",\n", " database=database,\n", " table=table,\n", - " table_type=\"GOVERNED\",\n", " description=\"c1\",\n", " parameters={\"num_cols\": str(len(df.columns)), \"num_rows\": str(len(df.index))},\n", " columns_comments={\"c1\": \"1\"}\n", @@ -323,7 +323,6 @@ " mode=\"append\",\n", " database=database,\n", " table=table,\n", - " table_type=\"GOVERNED\",\n", " description=\"c1\",\n", " parameters={\"num_cols\": str(len(df.columns)), \"num_rows\": str(len(df.index) * 2)},\n", " columns_comments={\"c1\": \"1\"}\n", @@ -348,7 +347,7 @@ "df = pd.DataFrame({\"c0\": [\"foo\", None], \"c1\": [0, 1]})\n", "wr.s3.to_parquet(\n", " df=df,\n", - " path=f\"s3://{bucket}/{table}/\",\n", + " path=f\"s3://{bucket}/{database}/{table}/\",\n", " dataset=True,\n", " database=database,\n", " table=table,\n", @@ -380,7 +379,6 @@ " mode=\"overwrite_partitions\",\n", " database=database,\n", " table=table,\n", - " table_type=\"GOVERNED\",\n", " partition_cols=[\"c1\"],\n", " description=\"c0+c1\",\n", " parameters={\"num_cols\": \"2\", \"num_rows\": \"3\"},\n", @@ -409,7 +407,7 @@ "df = pd.DataFrame({\"c0\": [0, None]}, dtype=\"Int64\")\n", "wr.s3.to_parquet(\n", " df=df,\n", - " path=f\"s3://{bucket}/{table}/{write_table}_1\",\n", + " path=f\"s3://{bucket}/{database}/{write_table}_1\",\n", " dataset=True,\n", " database=database,\n", " table=f\"{write_table}_1\",\n", @@ -427,7 +425,7 @@ "df3 = pd.DataFrame({\"c1\": [None, 1, None]}, dtype=\"Int16\")\n", "wr.s3.to_parquet(\n", " df=df2,\n", - " path=f\"s3://{bucket}/{table}/{write_table}_2\",\n", + " path=f\"s3://{bucket}/{database}/{write_table}_2\",\n", " dataset=True,\n", " mode=\"append\",\n", " database=database,\n", From 16788e3e9fd468ceca6140f18b8fffcb0c22fbe8 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Thu, 25 Feb 2021 11:22:58 +0000 Subject: [PATCH 23/25] Minor - Fixing _catalog_id call --- awswrangler/lakeformation/_utils.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/awswrangler/lakeformation/_utils.py b/awswrangler/lakeformation/_utils.py index 5088096b1..ea94101bd 100644 --- a/awswrangler/lakeformation/_utils.py +++ b/awswrangler/lakeformation/_utils.py @@ -70,13 +70,7 @@ def _get_table_objects( client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) scan_kwargs: Dict[str, Union[str, int]] = _catalog_id( - catalog_id=catalog_id, - **{ - "TransactionId": transaction_id, - "DatabaseName": database, - "TableName": table, - "MaxResults": 100, - }, + catalog_id=catalog_id, TransactionId=transaction_id, DatabaseName=database, TableName=table, MaxResults=100 ) if partition_cols and partitions_types and partitions_values: scan_kwargs["PartitionPredicate"] = _build_partition_predicate( @@ -111,12 +105,7 @@ def _update_table_objects( client_lakeformation: boto3.client = _utils.client(service_name="lakeformation", session=session) update_kwargs: Dict[str, Union[str, int, List[Dict[str, Dict[str, Any]]]]] = 
_catalog_id( - catalog_id=catalog_id, - **{ - "TransactionId": transaction_id, - "DatabaseName": database, - "TableName": table, - }, + catalog_id=catalog_id, TransactionId=transaction_id, DatabaseName=database, TableName=table ) write_operations: List[Dict[str, Dict[str, Any]]] = [] From 6e48f61f9ca7d7341d710253be3cb1b0755452f5 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Thu, 25 Feb 2021 14:48:47 +0000 Subject: [PATCH 24/25] Minor - Clarifying SQL filter operation --- .github/workflows/static-checking.yml | 1 - awswrangler/athena/_read.py | 4 ++-- awswrangler/lakeformation/_read.py | 6 +++--- awswrangler/sqlserver.py | 2 +- tutorials/029 - Lake Formation Governed Tables.ipynb | 4 ++-- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/.github/workflows/static-checking.yml b/.github/workflows/static-checking.yml index f4622631a..fafca0e93 100644 --- a/.github/workflows/static-checking.yml +++ b/.github/workflows/static-checking.yml @@ -9,7 +9,6 @@ on: branches: - main - main-governed-tables - - feature/lf-transactions jobs: Check: diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py index bca15e15f..dba3888fa 100644 --- a/awswrangler/athena/_read.py +++ b/awswrangler/athena/_read.py @@ -761,8 +761,8 @@ def read_sql_query( >>> import awswrangler as wr >>> df = wr.athena.read_sql_query( - ... sql="SELECT * FROM my_table WHERE name=:name;", - ... params={"name": "filtered_name"} + ... sql="SELECT * FROM my_table WHERE name=:name; AND city=:city;", + ... params={"name": "'filtered_name'", "city": "'filtered_city'"} ... ) """ diff --git a/awswrangler/lakeformation/_read.py b/awswrangler/lakeformation/_read.py index 007976bfd..d08c7a5d9 100644 --- a/awswrangler/lakeformation/_read.py +++ b/awswrangler/lakeformation/_read.py @@ -204,10 +204,10 @@ def read_sql_query( >>> import awswrangler as wr >>> df = wr.lakeformation.read_sql_query( - ... sql="SELECT * FROM my_table WHERE name=:name;", + ... sql="SELECT * FROM my_table WHERE name=:name; AND city=:city;", ... database="my_db", ... query_as_of_time="1611142914", - ... params={"name": "'filtered_name'"} + ... params={"name": "'filtered_name'", "city": "'filtered_city'"} ... 
) """ @@ -220,7 +220,7 @@ def read_sql_query( if params is None: params = {} for key, value in params.items(): - sql = sql.replace(f":{key}", str(value)) + sql = sql.replace(f":{key};", str(value)) args: Dict[str, Optional[str]] = _catalog_id(catalog_id=catalog_id, **{"DatabaseName": database, "Statement": sql}) if query_as_of_time: diff --git a/awswrangler/sqlserver.py b/awswrangler/sqlserver.py index a721a1430..b5f0ce483 100644 --- a/awswrangler/sqlserver.py +++ b/awswrangler/sqlserver.py @@ -15,7 +15,7 @@ __all__ = ["connect", "read_sql_query", "read_sql_table", "to_sql"] -_pyodbc_found = importlib.util.find_spec("pyodbc") +_pyodbc_found = importlib.util.find_spec("pyodbc") # type: ignore if _pyodbc_found: import pyodbc # pylint: disable=import-error diff --git a/tutorials/029 - Lake Formation Governed Tables.ipynb b/tutorials/029 - Lake Formation Governed Tables.ipynb index 56133b281..571b78a89 100644 --- a/tutorials/029 - Lake Formation Governed Tables.ipynb +++ b/tutorials/029 - Lake Formation Governed Tables.ipynb @@ -142,10 +142,10 @@ "\n", "query_as_of_time = query_as_of_time = calendar.timegm(time.gmtime())\n", "df = wr.lakeformation.read_sql_query(\n", - " sql=f\"SELECT * FROM {table} WHERE id = :id;\",\n", + " sql=f\"SELECT * FROM {table} WHERE id=:id; AND name=:name;\",\n", " database=database,\n", " query_as_of_time=query_as_of_time,\n", - " params={\"id\": 1}\n", + " params={\"id\": 1, \"name\": \"Ayoub\"}\n", ")" ] }, From 11081788a235dad28b663429a9e1f7cc167effb9 Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Thu, 25 Feb 2021 14:53:52 +0000 Subject: [PATCH 25/25] Minor - Removing type ignore --- awswrangler/sqlserver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/awswrangler/sqlserver.py b/awswrangler/sqlserver.py index b5f0ce483..a721a1430 100644 --- a/awswrangler/sqlserver.py +++ b/awswrangler/sqlserver.py @@ -15,7 +15,7 @@ __all__ = ["connect", "read_sql_query", "read_sql_table", "to_sql"] -_pyodbc_found = importlib.util.find_spec("pyodbc") # type: ignore +_pyodbc_found = importlib.util.find_spec("pyodbc") if _pyodbc_found: import pyodbc # pylint: disable=import-error
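
Usage note on PATCH 22/25 (table_type inference): once a Governed table already exists in the Glue Catalog, the writer looks its TableType up from the catalog entry, and passing a transaction_id alone now implies table_type="GOVERNED". Below is a minimal sketch of the resulting call pattern, using only arguments that already appear in the tutorial notebook; the bucket, database and table names are placeholders for resources you have already registered with Lake Formation, not values taken from the patches.

    import pandas as pd
    import awswrangler as wr

    bucket = "my-bucket"          # placeholder S3 bucket
    database = "gov_db"           # Glue database registered with Lake Formation
    table = "gov_table_parquet"   # Governed table name

    # First write: the table does not exist yet, so table_type must be given explicitly.
    wr.s3.to_parquet(
        df=pd.DataFrame({"c0": [0, 1]}),
        path=f"s3://{bucket}/{database}/{table}/",
        dataset=True,
        database=database,
        table=table,
        table_type="GOVERNED",
    )

    # Subsequent writes: both path and table_type are resolved from the existing
    # catalog entry, so neither needs to be repeated.
    wr.s3.to_parquet(
        df=pd.DataFrame({"c0": [2, 3]}),
        dataset=True,
        mode="append",
        database=database,
        table=table,
    )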
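
On PATCH 23/25 (_catalog_id call): the change only swaps an unpacked `**{...}` dict for plain keyword arguments, so the helper receives exactly the same kwargs either way. The stand-in below is a hypothetical sketch inferred from the call sites in the patch, not the library's actual implementation, but it makes that equivalence concrete.

    from typing import Any, Dict, Optional

    def _catalog_id_sketch(catalog_id: Optional[str] = None, **kwargs: Any) -> Dict[str, Any]:
        # Hypothetical stand-in: add CatalogId only when one was supplied and
        # hand every other keyword straight back as the boto3 call kwargs.
        if catalog_id is not None:
            kwargs["CatalogId"] = catalog_id
        return kwargs

    old_style = _catalog_id_sketch(
        catalog_id=None,
        **{"TransactionId": "txn-id", "DatabaseName": "gov_db", "TableName": "gov_table", "MaxResults": 100},
    )
    new_style = _catalog_id_sketch(
        catalog_id=None, TransactionId="txn-id", DatabaseName="gov_db", TableName="gov_table", MaxResults=100
    )
    assert old_style == new_style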
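
On PATCH 24/25 (SQL filter clarification): named parameters are now matched together with a trailing semicolon, i.e. the query text must contain `:name;` for a params key "name", and string values have to carry their own quotes (e.g. "'filtered_name'"). A self-contained sketch of that substitution, assuming nothing beyond the one-line replace shown in the diff:

    from typing import Any, Dict

    def substitute_params(sql: str, params: Dict[str, Any]) -> str:
        # Mirrors the replace in awswrangler/lakeformation/_read.py after the patch:
        # every ":key;" token is swapped for the stringified value.
        for key, value in params.items():
            sql = sql.replace(f":{key};", str(value))
        return sql

    sql = "SELECT * FROM my_table WHERE name=:name; AND city=:city;"
    print(substitute_params(sql, {"name": "'filtered_name'", "city": "'filtered_city'"}))
    # SELECT * FROM my_table WHERE name='filtered_name' AND city='filtered_city'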