From a0d07772277cfe10f8d76a2c0664aabe7b061b36 Mon Sep 17 00:00:00 2001 From: igorborgest Date: Mon, 18 May 2020 17:21:47 -0300 Subject: [PATCH 1/2] First schema evolution propose for parquet datasets. #232 --- awswrangler/_data_types.py | 35 ++- awswrangler/catalog.py | 183 ++++++++++-- awswrangler/s3.py | 37 ++- docs/source/api.rst | 2 + requirements-dev.txt | 3 +- testing/test_awswrangler/test_data_lake.py | 331 +++++++++++++++++---- tox.ini | 5 +- 7 files changed, 491 insertions(+), 105 deletions(-) diff --git a/awswrangler/_data_types.py b/awswrangler/_data_types.py index aa08860e8..8cd84e7b8 100644 --- a/awswrangler/_data_types.py +++ b/awswrangler/_data_types.py @@ -175,15 +175,15 @@ def pyarrow2sqlalchemy( # pylint: disable=too-many-branches,too-many-return-sta ) -> Optional[VisitableType]: """Pyarrow to Athena data types conversion.""" if pa.types.is_int8(dtype): - return sqlalchemy.types.SMALLINT + return sqlalchemy.types.SmallInteger if pa.types.is_int16(dtype): - return sqlalchemy.types.SMALLINT + return sqlalchemy.types.SmallInteger if pa.types.is_int32(dtype): - return sqlalchemy.types.INTEGER + return sqlalchemy.types.Integer if pa.types.is_int64(dtype): - return sqlalchemy.types.BIGINT + return sqlalchemy.types.BigInteger if pa.types.is_float32(dtype): - return sqlalchemy.types.FLOAT + return sqlalchemy.types.Float if pa.types.is_float64(dtype): if db_type == "mysql": return sqlalchemy.dialects.mysql.DOUBLE @@ -195,25 +195,25 @@ def pyarrow2sqlalchemy( # pylint: disable=too-many-branches,too-many-return-sta f"{db_type} is a invalid database type, please choose between postgresql, mysql and redshift." ) # pragma: no cover if pa.types.is_boolean(dtype): - return sqlalchemy.types.BOOLEAN + return sqlalchemy.types.Boolean if pa.types.is_string(dtype): if db_type == "mysql": - return sqlalchemy.types.TEXT + return sqlalchemy.types.Text if db_type == "postgresql": - return sqlalchemy.types.TEXT + return sqlalchemy.types.Text if db_type == "redshift": return sqlalchemy.types.VARCHAR(length=256) raise exceptions.InvalidDatabaseType( f"{db_type} is a invalid database type. " f"Please choose between postgresql, mysql and redshift." 
) # pragma: no cover if pa.types.is_timestamp(dtype): - return sqlalchemy.types.DATETIME + return sqlalchemy.types.DateTime if pa.types.is_date(dtype): - return sqlalchemy.types.DATE + return sqlalchemy.types.Date if pa.types.is_binary(dtype): if db_type == "redshift": raise exceptions.UnsupportedType("Binary columns are not supported for Redshift.") # pragma: no cover - return sqlalchemy.types.BINARY + return sqlalchemy.types.Binary if pa.types.is_decimal(dtype): return sqlalchemy.types.Numeric(precision=dtype.precision, scale=dtype.scale) if pa.types.is_dictionary(dtype): @@ -396,7 +396,7 @@ def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd df[col] = ( df[col] .astype("string") - .apply(lambda x: Decimal(str(x)) if str(x) not in ("", "none", " ", "") else None) + .apply(lambda x: Decimal(str(x)) if str(x) not in ("", "none", "None", " ", "") else None) ) elif pandas_type == "string": curr_type: str = str(df[col].dtypes) @@ -405,7 +405,16 @@ def cast_pandas_with_athena_types(df: pd.DataFrame, dtype: Dict[str, str]) -> pd else: df[col] = df[col].astype("string") else: - df[col] = df[col].astype(pandas_type) + try: + df[col] = df[col].astype(pandas_type) + except TypeError as ex: + if "object cannot be converted to an IntegerDtype" not in str(ex): + raise ex # pragma: no cover + df[col] = ( + df[col] + .apply(lambda x: int(x) if str(x) not in ("", "none", "None", " ", "") else None) + .astype(pandas_type) + ) return df diff --git a/awswrangler/catalog.py b/awswrangler/catalog.py index 49d748ed3..4a8d6b2d3 100644 --- a/awswrangler/catalog.py +++ b/awswrangler/catalog.py @@ -150,9 +150,33 @@ def create_parquet_table( """ table = sanitize_table_name(table=table) partitions_types = {} if partitions_types is None else partitions_types - table_input: Dict[str, Any] = _parquet_table_definition( - table=table, path=path, columns_types=columns_types, partitions_types=partitions_types, compression=compression - ) + session: boto3.Session = _utils.ensure_session(session=boto3_session) + cat_table_input: Optional[Dict[str, Any]] = _get_table_input(database=database, table=table, boto3_session=session) + table_input: Dict[str, Any] + if (cat_table_input is not None) and (mode in ("append", "overwrite_partitions")): + table_input = cat_table_input + updated: bool = False + cat_cols: Dict[str, str] = {x["Name"]: x["Type"] for x in table_input["StorageDescriptor"]["Columns"]} + for c, t in columns_types.items(): + if c not in cat_cols: + _logger.debug("New column %s with type %s.", c, t) + table_input["StorageDescriptor"]["Columns"].append({"Name": c, "Type": t}) + updated = True + elif t != cat_cols[c]: # Data type change detected! + raise exceptions.InvalidArgumentValue( + f"Data type change detected on column {c}. Old type: {cat_cols[c]}. New type {t}." 
+ ) + if updated is True: + mode = "update" + else: + table_input = _parquet_table_definition( + table=table, + path=path, + columns_types=columns_types, + partitions_types=partitions_types, + compression=compression, + ) + table_exist: bool = cat_table_input is not None _create_table( database=database, table=table, @@ -161,8 +185,9 @@ def create_parquet_table( columns_comments=columns_comments, mode=mode, catalog_versioning=catalog_versioning, - boto3_session=boto3_session, + boto3_session=session, table_input=table_input, + table_exist=table_exist, ) @@ -266,7 +291,9 @@ def _parquet_partition_definition(location: str, values: List[str], compression: } -def get_table_types(database: str, table: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, str]: +def get_table_types( + database: str, table: str, boto3_session: Optional[boto3.Session] = None +) -> Optional[Dict[str, str]]: """Get all columns and types from a table. Parameters @@ -280,8 +307,8 @@ def get_table_types(database: str, table: str, boto3_session: Optional[boto3.Ses Returns ------- - Dict[str, str] - A dictionary as {'col name': 'col data type'}. + Optional[Dict[str, str]] + If table exists, a dictionary like {'col name': 'col data type'}. Otherwise None. Examples -------- @@ -291,7 +318,10 @@ def get_table_types(database: str, table: str, boto3_session: Optional[boto3.Ses """ client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) - response: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table) + try: + response: Dict[str, Any] = client_glue.get_table(DatabaseName=database, Name=table) + except client_glue.exceptions.EntityNotFoundException: + return None dtypes: Dict[str, str] = {} for col in response["Table"]["StorageDescriptor"]["Columns"]: dtypes[col["Name"]] = col["Type"] @@ -938,6 +968,7 @@ def create_csv_table( compression=compression, sep=sep, ) + session: boto3.Session = _utils.ensure_session(session=boto3_session) _create_table( database=database, table=table, @@ -946,8 +977,9 @@ def create_csv_table( columns_comments=columns_comments, mode=mode, catalog_versioning=catalog_versioning, - boto3_session=boto3_session, + boto3_session=session, table_input=table_input, + table_exist=does_table_exist(database=database, table=table, boto3_session=session), ) @@ -961,6 +993,7 @@ def _create_table( catalog_versioning: bool, boto3_session: Optional[boto3.Session], table_input: Dict[str, Any], + table_exist: bool, ): if description is not None: table_input["Description"] = description @@ -978,13 +1011,12 @@ def _create_table( par["Comment"] = columns_comments[name] session: boto3.Session = _utils.ensure_session(session=boto3_session) client_glue: boto3.client = _utils.client(service_name="glue", session=session) - exist: bool = does_table_exist(database=database, table=table, boto3_session=session) - if mode not in ("overwrite", "append", "overwrite_partitions"): # pragma: no cover + skip_archive: bool = not catalog_versioning + if mode not in ("overwrite", "append", "overwrite_partitions", "update"): # pragma: no cover raise exceptions.InvalidArgument( f"{mode} is not a valid mode. It must be 'overwrite', 'append' or 'overwrite_partitions'." 
) - if (exist is True) and (mode == "overwrite"): - skip_archive: bool = not catalog_versioning + if (table_exist is True) and (mode == "overwrite"): partitions_values: List[List[str]] = list( _get_partitions(database=database, table=table, boto3_session=session).values() ) @@ -992,9 +1024,12 @@ def _create_table( DatabaseName=database, TableName=table, PartitionsToDelete=[{"Values": v} for v in partitions_values] ) client_glue.update_table(DatabaseName=database, TableInput=table_input, SkipArchive=skip_archive) - elif (exist is True) and (mode in ("append", "overwrite_partitions")) and (parameters is not None): - upsert_table_parameters(parameters=parameters, database=database, table=table, boto3_session=session) - elif exist is False: + elif (table_exist is True) and (mode in ("append", "overwrite_partitions", "update")): + if parameters is not None: + upsert_table_parameters(parameters=parameters, database=database, table=table, boto3_session=session) + if mode == "update": + client_glue.update_table(DatabaseName=database, TableInput=table_input, SkipArchive=skip_archive) + elif table_exist is False: client_glue.create_table(DatabaseName=database, TableInput=table_input) @@ -1379,6 +1414,88 @@ def get_table_parameters( return parameters +def get_table_description( + database: str, table: str, catalog_id: Optional[str] = None, boto3_session: Optional[boto3.Session] = None +) -> str: + """Get table description. + + Parameters + ---------- + database : str + Database name. + table : str + Table name. + catalog_id : str, optional + The ID of the Data Catalog from which to retrieve Databases. + If none is provided, the AWS account ID is used by default. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + + Returns + ------- + str + Description. + + Examples + -------- + >>> import awswrangler as wr + >>> desc = wr.catalog.get_table_description(database="...", table="...") + + """ + client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session) + args: Dict[str, str] = {} + if catalog_id is not None: + args["CatalogId"] = catalog_id # pragma: no cover + args["DatabaseName"] = database + args["Name"] = table + response: Dict[str, Any] = client_glue.get_table(**args) + desc: str = response["Table"]["Description"] + return desc + + +def get_columns_comments( + database: str, table: str, catalog_id: Optional[str] = None, boto3_session: Optional[boto3.Session] = None +) -> Dict[str, str]: + """Get all columns comments. + + Parameters + ---------- + database : str + Database name. + table : str + Table name. + catalog_id : str, optional + The ID of the Data Catalog from which to retrieve Databases. + If none is provided, the AWS account ID is used by default. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + + Returns + ------- + Dict[str, str] + Columns comments. e.g. {"col1": "foo boo bar"}. 
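A minimal usage sketch of the two catalog helpers introduced above (wr.catalog.get_table_description and wr.catalog.get_columns_comments), assuming a table that is already registered in the Glue Catalog; the database and table names below are placeholders:

    # Usage sketch only; "my_db" and "my_table" are placeholder names.
    import awswrangler as wr

    desc = wr.catalog.get_table_description(database="my_db", table="my_table")
    comments = wr.catalog.get_columns_comments(database="my_db", table="my_table")
    print(desc)      # e.g. "Foo boo bar"
    print(comments)  # e.g. {"col0": "my int", "y": "year"}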
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> comments = wr.catalog.get_columns_comments(database="...", table="...")
+
+    """
+    client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
+    args: Dict[str, str] = {}
+    if catalog_id is not None:
+        args["CatalogId"] = catalog_id  # pragma: no cover
+    args["DatabaseName"] = database
+    args["Name"] = table
+    response: Dict[str, Any] = client_glue.get_table(**args)
+    comments: Dict[str, str] = {}
+    for c in response["Table"]["StorageDescriptor"]["Columns"]:
+        comments[c["Name"]] = c["Comment"]
+    for p in response["Table"]["PartitionKeys"]:
+        comments[p["Name"]] = p["Comment"]
+    return comments
+
+
 def upsert_table_parameters(
     parameters: Dict[str, str],
     database: str,
@@ -1465,14 +1582,36 @@ def overwrite_table_parameters(
     ...     table="...")

     """
+    session: boto3.Session = _utils.ensure_session(session=boto3_session)
+    table_input: Optional[Dict[str, Any]] = _get_table_input(
+        database=database, table=table, catalog_id=catalog_id, boto3_session=session
+    )
+    if table_input is None:
+        raise exceptions.InvalidTable(f"Table {table} does not exist on database {database}.")
+    table_input["Parameters"] = parameters
+    args2: Dict[str, Union[str, Dict[str, Any]]] = {}
+    if catalog_id is not None:
+        args2["CatalogId"] = catalog_id  # pragma: no cover
+    args2["DatabaseName"] = database
+    args2["TableInput"] = table_input
+    client_glue: boto3.client = _utils.client(service_name="glue", session=session)
+    client_glue.update_table(**args2)
+    return parameters
+
+
+def _get_table_input(
+    database: str, table: str, boto3_session: Optional[boto3.Session], catalog_id: Optional[str] = None
+) -> Optional[Dict[str, Any]]:
     client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
     args: Dict[str, str] = {}
     if catalog_id is not None:
         args["CatalogId"] = catalog_id  # pragma: no cover
     args["DatabaseName"] = database
     args["Name"] = table
-    response: Dict[str, Any] = client_glue.get_table(**args)
-    response["Table"]["Parameters"] = parameters
+    try:
+        response: Dict[str, Any] = client_glue.get_table(**args)
+    except client_glue.exceptions.EntityNotFoundException:
+        return None
     if "DatabaseName" in response["Table"]:
         del response["Table"]["DatabaseName"]
     if "CreateTime" in response["Table"]:
@@ -1483,10 +1622,4 @@ def overwrite_table_parameters(
         del response["Table"]["CreatedBy"]
     if "IsRegisteredWithLakeFormation" in response["Table"]:
         del response["Table"]["IsRegisteredWithLakeFormation"]
-    args2: Dict[str, Union[str, Dict[str, Any]]] = {}
-    if catalog_id is not None:
-        args2["CatalogId"] = catalog_id  # pragma: no cover
-    args2["DatabaseName"] = database
-    args2["TableInput"] = response["Table"]
-    client_glue.update_table(**args2)
-    return parameters
+    return response["Table"]
diff --git a/awswrangler/s3.py b/awswrangler/s3.py
index 1a5590579..b68a1e433 100644
--- a/awswrangler/s3.py
+++ b/awswrangler/s3.py
@@ -627,11 +627,18 @@ def to_csv(  # pylint: disable=too-many-arguments
     )
     if df.empty is True:
         raise exceptions.EmptyDataFrame()
-    session: boto3.Session = _utils.ensure_session(session=boto3_session)
+
+    # Sanitize table to respect Athena's standards
     partition_cols = partition_cols if partition_cols else []
    dtype = dtype if dtype else {}
     columns_comments = columns_comments if columns_comments else {}
     partitions_values: Dict[str, List[str]] = {}
+    df = catalog.sanitize_dataframe_columns_names(df=df)
+    partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols]
+    dtype =
{catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()} + columns_comments = {catalog.sanitize_column_name(k): v for k, v in columns_comments.items()} + + session: boto3.Session = _utils.ensure_session(session=boto3_session) fs: s3fs.S3FileSystem = _utils.get_fs(session=session, s3_additional_kwargs=s3_additional_kwargs) if dataset is False: if partition_cols: @@ -653,14 +660,14 @@ def to_csv( # pylint: disable=too-many-arguments mode = "append" if mode is None else mode if columns: df = df[columns] - if (database is not None) and (table is not None): # Normalize table to respect Athena's standards - df = catalog.sanitize_dataframe_columns_names(df=df) - partition_cols = [catalog.sanitize_column_name(p) for p in partition_cols] - dtype = {catalog.sanitize_column_name(k): v.lower() for k, v in dtype.items()} - columns_comments = {catalog.sanitize_column_name(k): v for k, v in columns_comments.items()} - exist: bool = catalog.does_table_exist(database=database, table=table, boto3_session=session) - if (exist is True) and (mode in ("append", "overwrite_partitions")): - for k, v in catalog.get_table_types(database=database, table=table, boto3_session=session).items(): + if ( + (mode in ("append", "overwrite_partitions")) and (database is not None) and (table is not None) + ): # Fetching Catalog Types + catalog_types: Optional[Dict[str, str]] = catalog.get_table_types( + database=database, table=table, boto3_session=session + ) + if catalog_types is not None: + for k, v in catalog_types.items(): dtype[k] = v df = catalog.drop_duplicated_columns(df=df) paths, partitions_values = _to_csv_dataset( @@ -1083,10 +1090,14 @@ def to_parquet( # pylint: disable=too-many-arguments ] else: mode = "append" if mode is None else mode - if (database is not None) and (table is not None): - exist: bool = catalog.does_table_exist(database=database, table=table, boto3_session=session) - if (exist is True) and (mode in ("append", "overwrite_partitions")): - for k, v in catalog.get_table_types(database=database, table=table, boto3_session=session).items(): + if ( + (mode in ("append", "overwrite_partitions")) and (database is not None) and (table is not None) + ): # Fetching Catalog Types + catalog_types: Optional[Dict[str, str]] = catalog.get_table_types( + database=database, table=table, boto3_session=session + ) + if catalog_types is not None: + for k, v in catalog_types.items(): dtype[k] = v paths, partitions_values = _to_parquet_dataset( df=df, diff --git a/docs/source/api.rst b/docs/source/api.rst index 7f36e5aa5..9d1619f38 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -64,6 +64,8 @@ AWS Glue Catalog get_engine extract_athena_types get_table_parameters + get_columns_comments + get_table_description upsert_table_parameters overwrite_table_parameters diff --git a/requirements-dev.txt b/requirements-dev.txt index 76554ff30..30cb6182b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,7 +10,8 @@ tox~=3.15.0 pytest~=5.4.2 pytest-cov~=2.8.1 pytest-xdist~=1.32.0 -scikit-learn~=0.22.1 +pytest-timeout~=1.3.4 +scikit-learn~=0.23.0 cfn-lint~=0.31.1 cfn-flip~=1.2.3 twine~=3.1.1 diff --git a/testing/test_awswrangler/test_data_lake.py b/testing/test_awswrangler/test_data_lake.py index 267b29011..34e1adf29 100644 --- a/testing/test_awswrangler/test_data_lake.py +++ b/testing/test_awswrangler/test_data_lake.py @@ -178,7 +178,7 @@ def path(bucket): @pytest.fixture(scope="function") def table(database): - name = get_time_str_with_random_suffix() + name = 
f"tbl_{get_time_str_with_random_suffix()}" print(f"Table name: {name}") wr.catalog.delete_table_if_exists(database=database, table=name) yield name @@ -202,7 +202,7 @@ def path2(bucket): @pytest.fixture(scope="function") def table2(database): - name = get_time_str_with_random_suffix() + name = f"tbl_{get_time_str_with_random_suffix()}" print(f"Table name: {name}") wr.catalog.delete_table_if_exists(database=database, table=name) yield name @@ -573,28 +573,27 @@ def test_parquet_catalog_casting(bucket, database): assert wr.catalog.delete_table_if_exists(database=database, table="__test_parquet_catalog_casting") is True -def test_catalog(bucket, database): +def test_catalog(path, database, table): account_id = boto3.client("sts").get_caller_identity().get("Account") - path = f"s3://{bucket}/test_catalog/" - wr.catalog.delete_table_if_exists(database=database, table="test_catalog") - assert wr.catalog.does_table_exist(database=database, table="test_catalog") is False + assert wr.catalog.does_table_exist(database=database, table=table) is False wr.catalog.create_parquet_table( database=database, - table="test_catalog", + table=table, path=path, columns_types={"col0": "int", "col1": "double"}, partitions_types={"y": "int", "m": "int"}, compression="snappy", ) - wr.catalog.create_parquet_table( - database=database, table="test_catalog", path=path, columns_types={"col0": "string"}, mode="append" - ) - assert wr.catalog.does_table_exist(database=database, table="test_catalog") is True - assert wr.catalog.delete_table_if_exists(database=database, table="test_catalog") is True - assert wr.catalog.delete_table_if_exists(database=database, table="test_catalog") is False + with pytest.raises(wr.exceptions.InvalidArgumentValue): + wr.catalog.create_parquet_table( + database=database, table=table, path=path, columns_types={"col0": "string"}, mode="append" + ) + assert wr.catalog.does_table_exist(database=database, table=table) is True + assert wr.catalog.delete_table_if_exists(database=database, table=table) is True + assert wr.catalog.delete_table_if_exists(database=database, table=table) is False wr.catalog.create_parquet_table( database=database, - table="test_catalog", + table=table, path=path, columns_types={"col0": "int", "col1": "double"}, partitions_types={"y": "int", "m": "int"}, @@ -602,22 +601,23 @@ def test_catalog(bucket, database): description="Foo boo bar", parameters={"tag": "test"}, columns_comments={"col0": "my int", "y": "year"}, + mode="overwrite", ) wr.catalog.add_parquet_partitions( database=database, - table="test_catalog", + table=table, partitions_values={f"{path}y=2020/m=1/": ["2020", "1"], f"{path}y=2021/m=2/": ["2021", "2"]}, compression="snappy", ) - assert wr.catalog.get_table_location(database=database, table="test_catalog") == path - partitions_values = wr.catalog.get_parquet_partitions(database=database, table="test_catalog") + assert wr.catalog.get_table_location(database=database, table=table) == path + partitions_values = wr.catalog.get_parquet_partitions(database=database, table=table) assert len(partitions_values) == 2 partitions_values = wr.catalog.get_parquet_partitions( - database=database, table="test_catalog", catalog_id=account_id, expression="y = 2021 AND m = 2" + database=database, table=table, catalog_id=account_id, expression="y = 2021 AND m = 2" ) assert len(partitions_values) == 1 assert len(set(partitions_values[f"{path}y=2021/m=2/"]) & {"2021", "2"}) == 2 - dtypes = wr.catalog.get_table_types(database=database, table="test_catalog") + dtypes = 
wr.catalog.get_table_types(database=database, table=table) assert dtypes["col0"] == "int" assert dtypes["col1"] == "double" assert dtypes["y"] == "int" @@ -628,7 +628,7 @@ def test_catalog(bucket, database): tables = list(wr.catalog.get_tables()) assert len(tables) > 0 for tbl in tables: - if tbl["Name"] == "test_catalog": + if tbl["Name"] == table: assert tbl["TableType"] == "EXTERNAL_TABLE" tables = list(wr.catalog.get_tables(database=database)) assert len(tables) > 0 @@ -638,37 +638,41 @@ def test_catalog(bucket, database): tables = list(wr.catalog.search_tables(text="parquet", catalog_id=account_id)) assert len(tables) > 0 for tbl in tables: - if tbl["Name"] == "test_catalog": + if tbl["Name"] == table: assert tbl["TableType"] == "EXTERNAL_TABLE" # prefix - tables = list(wr.catalog.get_tables(name_prefix="test_cat", catalog_id=account_id)) + tables = list(wr.catalog.get_tables(name_prefix=table[:4], catalog_id=account_id)) assert len(tables) > 0 for tbl in tables: - if tbl["Name"] == "test_catalog": + if tbl["Name"] == table: assert tbl["TableType"] == "EXTERNAL_TABLE" # suffix - tables = list(wr.catalog.get_tables(name_suffix="_catalog", catalog_id=account_id)) + tables = list(wr.catalog.get_tables(name_suffix=table[-4:], catalog_id=account_id)) assert len(tables) > 0 for tbl in tables: - if tbl["Name"] == "test_catalog": + if tbl["Name"] == table: assert tbl["TableType"] == "EXTERNAL_TABLE" # name_contains - tables = list(wr.catalog.get_tables(name_contains="cat", catalog_id=account_id)) + tables = list(wr.catalog.get_tables(name_contains=table[4:-4], catalog_id=account_id)) assert len(tables) > 0 for tbl in tables: - if tbl["Name"] == "test_catalog": + if tbl["Name"] == table: assert tbl["TableType"] == "EXTERNAL_TABLE" # prefix & suffix & name_contains - tables = list(wr.catalog.get_tables(name_prefix="t", name_contains="_", name_suffix="g", catalog_id=account_id)) + tables = list( + wr.catalog.get_tables( + name_prefix=table[0], name_contains=table[3], name_suffix=table[-1], catalog_id=account_id + ) + ) assert len(tables) > 0 for tbl in tables: - if tbl["Name"] == "test_catalog": + if tbl["Name"] == table: assert tbl["TableType"] == "EXTERNAL_TABLE" # prefix & suffix - tables = list(wr.catalog.get_tables(name_prefix="t", name_suffix="g", catalog_id=account_id)) + tables = list(wr.catalog.get_tables(name_prefix=table[0], name_suffix=table[-1], catalog_id=account_id)) assert len(tables) > 0 for tbl in tables: - if tbl["Name"] == "test_catalog": + if tbl["Name"] == table: assert tbl["TableType"] == "EXTERNAL_TABLE" # DataFrames assert len(wr.catalog.databases().index) > 0 @@ -678,17 +682,18 @@ def test_catalog(bucket, database): wr.catalog.tables( database=database, search_text="parquet", - name_prefix="t", - name_contains="_", - name_suffix="g", + name_prefix=table[0], + name_contains=table[3], + name_suffix=table[-1], catalog_id=account_id, ).index ) > 0 ) - assert len(wr.catalog.table(database=database, table="test_catalog").index) > 0 - assert len(wr.catalog.table(database=database, table="test_catalog", catalog_id=account_id).index) > 0 - assert wr.catalog.delete_table_if_exists(database=database, table="test_catalog") is True + assert len(wr.catalog.table(database=database, table=table).index) > 0 + assert len(wr.catalog.table(database=database, table=table, catalog_id=account_id).index) > 0 + with pytest.raises(wr.exceptions.InvalidTable): + wr.catalog.overwrite_table_parameters({"foo": "boo"}, database, "fake_table") def test_s3_get_bucket_region(bucket, region): @@ -1118,10 
+1123,7 @@ def test_csv_compress(bucket, compression): wr.s3.delete_objects(path=path) -def test_parquet_char_length(bucket, database, external_schema): - path = f"s3://{bucket}/test_parquet_char_length/" - table = "test_parquet_char_length" - +def test_parquet_char_length(path, database, table, external_schema): df = pd.DataFrame( {"id": [1, 2], "cchar": ["foo", "boo"], "date": [datetime.date(2020, 1, 1), datetime.date(2020, 1, 2)]} ) @@ -1152,9 +1154,6 @@ def test_parquet_char_length(bucket, database, external_schema): assert len(df2.columns) == 3 assert df2.id.sum() == 3 - wr.s3.delete_objects(path=path) - assert wr.catalog.delete_table_if_exists(database=database, table=table) is True - def test_merge(bucket): path = f"s3://{bucket}/test_merge/" @@ -1475,10 +1474,7 @@ def test_parquet_uint64(bucket): wr.s3.delete_objects(path=path) -def test_parquet_overwrite_partition_cols(bucket, database, external_schema): - table = "test_parquet_overwrite_partition_cols" - path = f"s3://{bucket}/{table}/" - wr.s3.delete_objects(path=path) +def test_parquet_overwrite_partition_cols(path, database, table, external_schema): df = pd.DataFrame({"c0": [1, 2, 1, 2], "c1": [1, 2, 1, 2], "c2": [2, 1, 2, 1]}) paths = wr.s3.to_parquet( @@ -1511,9 +1507,6 @@ def test_parquet_overwrite_partition_cols(bucket, database, external_schema): assert df.c1.sum() == 6 assert df.c2.sum() == 6 - wr.s3.delete_objects(path=path) - wr.catalog.delete_table_if_exists(database=database, table=table) - def test_catalog_parameters(bucket, database): table = "test_catalog_parameters" @@ -1695,3 +1688,239 @@ def test_to_parquet_file_sanitize(path): assert df2.c0.sum() == 1 assert df2.camel_case.sum() == 5 assert df2.c_2.sum() == 9 + + +def test_to_parquet_modes(database, table, path, external_schema): + + # Round 1 - Warm up + df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") + paths = wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + mode="overwrite", + database=database, + table=table, + description="c0", + parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, + columns_comments={"c0": "0"}, + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 = wr.athena.read_sql_table(table, database) + assert df.shape == df2.shape + assert df.c0.sum() == df2.c0.sum() + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == str(len(df2.columns)) + assert parameters["num_rows"] == str(len(df2.index)) + assert wr.catalog.get_table_description(database, table) == "c0" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c0"] == "0" + + # Round 2 - Overwrite + df = pd.DataFrame({"c1": [None, 1, None]}, dtype="Int16") + paths = wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + mode="overwrite", + database=database, + table=table, + description="c1", + parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, + columns_comments={"c1": "1"}, + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 = wr.athena.read_sql_table(table, database) + assert df.shape == df2.shape + assert df.c1.sum() == df2.c1.sum() + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == str(len(df2.columns)) + assert parameters["num_rows"] == str(len(df2.index)) + assert wr.catalog.get_table_description(database, table) == "c1" + comments = wr.catalog.get_columns_comments(database, table) + 
assert len(comments) == len(df.columns) + assert comments["c1"] == "1" + + # Round 3 - Append + df = pd.DataFrame({"c1": [None, 2, None]}, dtype="Int8") + paths = wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + mode="append", + database=database, + table=table, + description="c1", + parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index) * 2)}, + columns_comments={"c1": "1"}, + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 = wr.athena.read_sql_table(table, database) + assert len(df.columns) == len(df2.columns) + assert len(df.index) * 2 == len(df2.index) + assert df.c1.sum() + 1 == df2.c1.sum() + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == str(len(df2.columns)) + assert parameters["num_rows"] == str(len(df2.index)) + assert wr.catalog.get_table_description(database, table) == "c1" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c1"] == "1" + + # Round 4 - Append + New Column + df = pd.DataFrame({"c2": ["a", None, "b"], "c1": [None, None, None]}) + paths = wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + mode="append", + database=database, + table=table, + description="c1+c2", + parameters={"num_cols": "2", "num_rows": "9"}, + columns_comments={"c1": "1", "c2": "2"}, + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 = wr.athena.read_sql_table(table, database) + assert len(df2.columns) == 2 + assert len(df2.index) == 9 + assert df2.c1.sum() == 3 + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == "2" + assert parameters["num_rows"] == "9" + assert wr.catalog.get_table_description(database, table) == "c1+c2" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c1"] == "1" + assert comments["c2"] == "2" + + # Round 5 - Append + New Column + Wrong Types + df = pd.DataFrame({"c2": [1], "c3": [True], "c1": ["1"]}) + paths = wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + mode="append", + database=database, + table=table, + description="c1+c2+c3", + parameters={"num_cols": "3", "num_rows": "10"}, + columns_comments={"c1": "1!", "c2": "2!", "c3": "3"}, + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 = wr.athena.read_sql_table(table, database) + assert len(df2.columns) == 3 + assert len(df2.index) == 10 + assert df2.c1.sum() == 4 + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == "3" + assert parameters["num_rows"] == "10" + assert wr.catalog.get_table_description(database, table) == "c1+c2+c3" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c1"] == "1!" + assert comments["c2"] == "2!" 
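The append rounds above reduce to the write pattern sketched below (bucket, database, and table names are placeholders): an append that carries a brand-new column now extends the Glue table definition, while declaring a different type for an existing column makes create_parquet_table raise InvalidArgumentValue.

    # Sketch of the append-time schema evolution exercised by the rounds above.
    import pandas as pd
    import awswrangler as wr

    path = "s3://my-bucket/my-table/"  # placeholder

    # First write defines the table with a single column "c0".
    wr.s3.to_parquet(
        df=pd.DataFrame({"c0": [1, 2]}),
        path=path, dataset=True, mode="overwrite",
        database="my_db", table="my_table",
    )

    # An append carrying a new column "c1" evolves the catalog definition:
    # "c1" is appended to the table's columns instead of being rejected.
    wr.s3.to_parquet(
        df=pd.DataFrame({"c0": [3], "c1": ["foo"]}),
        path=path, dataset=True, mode="append",
        database="my_db", table="my_table",
    )

    # Declaring a new type for an existing column, by contrast, makes
    # wr.catalog.create_parquet_table(..., mode="append") raise
    # wr.exceptions.InvalidArgumentValue (see test_catalog above).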
+ assert comments["c3"] == "3" + engine = wr.catalog.get_engine("aws-data-wrangler-redshift") + df3 = wr.db.read_sql_table(con=engine, table=table, schema=external_schema) + assert len(df3.columns) == 3 + assert len(df3.index) == 10 + assert df3.c1.sum() == 4 + + # Round 6 - Overwrite Partitioned + df = pd.DataFrame({"c0": ["foo", None], "c1": [0, 1]}) + paths = wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + mode="overwrite", + database=database, + table=table, + partition_cols=["c1"], + description="c0+c1", + parameters={"num_cols": "2", "num_rows": "2"}, + columns_comments={"c0": "zero", "c1": "one"}, + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 = wr.athena.read_sql_table(table, database) + assert df.shape == df2.shape + assert df.c1.sum() == df2.c1.sum() + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == "2" + assert parameters["num_rows"] == "2" + assert wr.catalog.get_table_description(database, table) == "c0+c1" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c0"] == "zero" + assert comments["c1"] == "one" + + # Round 7 - Overwrite Partitions + df = pd.DataFrame({"c0": [None, None], "c1": [0, 2]}) + paths = wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + mode="overwrite_partitions", + database=database, + table=table, + partition_cols=["c1"], + description="c0+c1", + parameters={"num_cols": "2", "num_rows": "3"}, + columns_comments={"c0": "zero", "c1": "one"}, + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 = wr.athena.read_sql_table(table, database) + assert len(df2.columns) == 2 + assert len(df2.index) == 3 + assert df2.c1.sum() == 3 + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == "2" + assert parameters["num_rows"] == "3" + assert wr.catalog.get_table_description(database, table) == "c0+c1" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c0"] == "zero" + assert comments["c1"] == "one" + + # Round 8 - Overwrite Partitions + New Column + Wrong Type + df = pd.DataFrame({"c0": [1, 2], "c1": ["1", "3"], "c2": [True, False]}) + paths = wr.s3.to_parquet( + df=df, + path=path, + dataset=True, + mode="overwrite_partitions", + database=database, + table=table, + partition_cols=["c1"], + description="c0+c1+c2", + parameters={"num_cols": "3", "num_rows": "4"}, + columns_comments={"c0": "zero", "c1": "one", "c2": "two"}, + )["paths"] + wr.s3.wait_objects_exist(paths=paths) + df2 = wr.athena.read_sql_table(table, database) + assert len(df2.columns) == 3 + assert len(df2.index) == 4 + assert df2.c1.sum() == 6 + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == "3" + assert parameters["num_rows"] == "4" + assert wr.catalog.get_table_description(database, table) == "c0+c1+c2" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c0"] == "zero" + assert comments["c1"] == "one" + assert comments["c2"] == "two" + engine = wr.catalog.get_engine("aws-data-wrangler-redshift") + df3 = wr.db.read_sql_table(con=engine, table=table, schema=external_schema) + assert len(df3.columns) == 3 + assert len(df3.index) == 4 + assert df3.c1.sum() == 6 diff --git a/tox.ini b/tox.ini index 018d4cd87..05764a50f 100644 --- 
a/tox.ini +++ b/tox.ini @@ -5,13 +5,14 @@ envlist = py{37,38,36} deps = pytest pytest-xdist + pytest-timeout moto commands = - pytest -n 8 testing/test_awswrangler + pytest --timeout=600 -n 8 testing/test_awswrangler [testenv:py36] deps = {[testenv]deps} pytest-cov commands = - pytest --cov=awswrangler -n 8 testing/test_awswrangler + pytest --timeout=600 --cov=awswrangler -n 8 testing/test_awswrangler From 6604c066b487f2ff7cff6343a9a9d49736d50724 Mon Sep 17 00:00:00 2001 From: igorborgest Date: Tue, 19 May 2020 00:23:25 -0300 Subject: [PATCH 2/2] Add test_store_parquet_metadata_modes() --- testing/test_awswrangler/test_data_lake.py | 204 +++++++++++++++++++++ 1 file changed, 204 insertions(+) diff --git a/testing/test_awswrangler/test_data_lake.py b/testing/test_awswrangler/test_data_lake.py index 34e1adf29..e4f733876 100644 --- a/testing/test_awswrangler/test_data_lake.py +++ b/testing/test_awswrangler/test_data_lake.py @@ -1924,3 +1924,207 @@ def test_to_parquet_modes(database, table, path, external_schema): assert len(df3.columns) == 3 assert len(df3.index) == 4 assert df3.c1.sum() == 6 + + +def test_store_parquet_metadata_modes(database, table, path, external_schema): + + # Round 1 - Warm up + df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") + paths = wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite")["paths"] + wr.s3.wait_objects_exist(paths=paths) + wr.s3.store_parquet_metadata( + path=path, + dataset=True, + mode="overwrite", + database=database, + table=table, + description="c0", + parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, + columns_comments={"c0": "0"}, + ) + df2 = wr.athena.read_sql_table(table, database) + assert df.shape == df2.shape + assert df.c0.sum() == df2.c0.sum() + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == str(len(df2.columns)) + assert parameters["num_rows"] == str(len(df2.index)) + assert wr.catalog.get_table_description(database, table) == "c0" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c0"] == "0" + + # Round 2 - Overwrite + df = pd.DataFrame({"c1": [None, 1, None]}, dtype="Int16") + paths = wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite")["paths"] + wr.s3.wait_objects_exist(paths=paths) + wr.s3.store_parquet_metadata( + path=path, + dataset=True, + mode="overwrite", + database=database, + table=table, + description="c1", + parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, + columns_comments={"c1": "1"}, + ) + df2 = wr.athena.read_sql_table(table, database) + assert df.shape == df2.shape + assert df.c1.sum() == df2.c1.sum() + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == str(len(df2.columns)) + assert parameters["num_rows"] == str(len(df2.index)) + assert wr.catalog.get_table_description(database, table) == "c1" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c1"] == "1" + + # Round 3 - Append + df = pd.DataFrame({"c1": [None, 2, None]}, dtype="Int16") + paths = wr.s3.to_parquet(df=df, path=path, dataset=True, mode="append")["paths"] + wr.s3.wait_objects_exist(paths=paths) + wr.s3.store_parquet_metadata( + path=path, + dataset=True, + mode="append", + database=database, + table=table, + description="c1", + parameters={"num_cols": str(len(df.columns)), 
"num_rows": str(len(df.index) * 2)}, + columns_comments={"c1": "1"}, + ) + df2 = wr.athena.read_sql_table(table, database) + assert len(df.columns) == len(df2.columns) + assert len(df.index) * 2 == len(df2.index) + assert df.c1.sum() + 1 == df2.c1.sum() + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == str(len(df2.columns)) + assert parameters["num_rows"] == str(len(df2.index)) + assert wr.catalog.get_table_description(database, table) == "c1" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c1"] == "1" + + # Round 4 - Append + New Column + df = pd.DataFrame({"c2": ["a", None, "b"], "c1": [None, 1, None]}) + df["c1"] = df["c1"].astype("Int16") + paths = wr.s3.to_parquet(df=df, path=path, dataset=True, mode="append")["paths"] + wr.s3.wait_objects_exist(paths=paths) + wr.s3.store_parquet_metadata( + path=path, + dataset=True, + mode="append", + database=database, + table=table, + description="c1+c2", + parameters={"num_cols": "2", "num_rows": "9"}, + columns_comments={"c1": "1", "c2": "2"}, + ) + df2 = wr.athena.read_sql_table(table, database) + assert len(df2.columns) == 2 + assert len(df2.index) == 9 + assert df2.c1.sum() == 4 + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == "2" + assert parameters["num_rows"] == "9" + assert wr.catalog.get_table_description(database, table) == "c1+c2" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c1"] == "1" + assert comments["c2"] == "2" + + # Round 5 - Overwrite Partitioned + df = pd.DataFrame({"c0": ["foo", None], "c1": [0, 1]}) + paths = wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite", partition_cols=["c1"])["paths"] + wr.s3.wait_objects_exist(paths=paths) + wr.s3.store_parquet_metadata( + path=path, + dataset=True, + mode="overwrite", + database=database, + table=table, + description="c0+c1", + parameters={"num_cols": "2", "num_rows": "2"}, + columns_comments={"c0": "zero", "c1": "one"}, + ) + df2 = wr.athena.read_sql_table(table, database) + assert df.shape == df2.shape + assert df.c1.sum() == df2.c1.astype(int).sum() + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == "2" + assert parameters["num_rows"] == "2" + assert wr.catalog.get_table_description(database, table) == "c0+c1" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c0"] == "zero" + assert comments["c1"] == "one" + + # Round 6 - Overwrite Partitions + df = pd.DataFrame({"c0": [None, "boo"], "c1": [0, 2]}) + paths = wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite_partitions", partition_cols=["c1"])[ + "paths" + ] + wr.s3.wait_objects_exist(paths=paths) + wr.s3.store_parquet_metadata( + path=path, + dataset=True, + mode="append", + database=database, + table=table, + description="c0+c1", + parameters={"num_cols": "2", "num_rows": "3"}, + columns_comments={"c0": "zero", "c1": "one"}, + ) + df2 = wr.athena.read_sql_table(table, database) + assert len(df2.columns) == 2 + assert len(df2.index) == 3 + assert df2.c1.astype(int).sum() == 3 + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == "2" + assert parameters["num_rows"] == "3" + 
assert wr.catalog.get_table_description(database, table) == "c0+c1" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c0"] == "zero" + assert comments["c1"] == "one" + + # Round 7 - Overwrite Partitions + New Column + df = pd.DataFrame({"c0": ["bar", None], "c1": [1, 3], "c2": [True, False]}) + paths = wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite_partitions", partition_cols=["c1"])[ + "paths" + ] + wr.s3.wait_objects_exist(paths=paths) + wr.s3.store_parquet_metadata( + path=path, + dataset=True, + mode="append", + database=database, + table=table, + description="c0+c1+c2", + parameters={"num_cols": "3", "num_rows": "4"}, + columns_comments={"c0": "zero", "c1": "one", "c2": "two"}, + ) + df2 = wr.athena.read_sql_table(table, database) + assert len(df2.columns) == 3 + assert len(df2.index) == 4 + assert df2.c1.astype(int).sum() == 6 + parameters = wr.catalog.get_table_parameters(database, table) + assert len(parameters) == 5 + assert parameters["num_cols"] == "3" + assert parameters["num_rows"] == "4" + assert wr.catalog.get_table_description(database, table) == "c0+c1+c2" + comments = wr.catalog.get_columns_comments(database, table) + assert len(comments) == len(df.columns) + assert comments["c0"] == "zero" + assert comments["c1"] == "one" + assert comments["c2"] == "two" + engine = wr.catalog.get_engine("aws-data-wrangler-redshift") + df3 = wr.db.read_sql_table(con=engine, table=table, schema=external_schema) + assert len(df3.columns) == 3 + assert len(df3.index) == 4 + assert df3.c1.astype(int).sum() == 6
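The final test walks through a write-then-register flow that is worth seeing outside the test harness: write plain Parquet files first, then let wr.s3.store_parquet_metadata crawl them and create or update the catalog entry. A rough sketch, with placeholder bucket and database names:

    # Rough sketch of the flow covered by test_store_parquet_metadata_modes().
    import pandas as pd
    import awswrangler as wr

    path = "s3://my-bucket/my-table/"  # placeholder

    # 1) Write Parquet files only; the Glue Catalog is not touched yet.
    paths = wr.s3.to_parquet(df=pd.DataFrame({"c0": [0, 1]}), path=path, dataset=True, mode="overwrite")["paths"]
    wr.s3.wait_objects_exist(paths=paths)

    # 2) Crawl the files and create/update the table, including the description
    #    and column comments that the new catalog helpers read back.
    wr.s3.store_parquet_metadata(
        path=path,
        dataset=True,
        mode="overwrite",
        database="my_db",
        table="my_table",
        description="c0",
        columns_comments={"c0": "zero"},
    )

    # 3) Check the metadata through the helpers added in this patch.
    assert wr.catalog.get_table_description("my_db", "my_table") == "c0"
    assert wr.catalog.get_columns_comments("my_db", "my_table")["c0"] == "zero"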