diff --git a/src/sagemaker/feature_store/dataset_builder.py b/src/sagemaker/feature_store/dataset_builder.py index 1e5eb6d39e..12d2de5b9d 100644 --- a/src/sagemaker/feature_store/dataset_builder.py +++ b/src/sagemaker/feature_store/dataset_builder.py @@ -21,12 +21,15 @@ import os from typing import Any, Dict, List, Tuple, Union +import logging + import attr import pandas as pd from sagemaker import Session, s3, utils from sagemaker.feature_store.feature_group import FeatureDefinition, FeatureGroup, FeatureTypeEnum +logger = logging.getLogger(__name__) _DEFAULT_CATALOG = "AwsDataCatalog" _DEFAULT_DATABASE = "sagemaker_featurestore" @@ -185,6 +188,9 @@ class DatasetBuilder: when calling "to_dataframe" should include duplicated records (default: False). _include_deleted_records (bool): A boolean representing whether the resulting dataframe when calling "to_dataframe" should include deleted records (default: False). + _cleanup_temporary_tables (bool): A boolean representing whether temporary tables are + cleaned up after calling to_dataframe when a dataframe is the base. This requires + Glue:DeleteTable permissions in your execution role. (default: False) _number_of_recent_records (int): An integer representing how many records will be returned for each record identifier (default: 1). 
_number_of_records (int): An integer representing the number of records that should be @@ -215,6 +221,7 @@ class DatasetBuilder: _point_in_time_accurate_join: bool = attr.ib(init=False, default=False) _include_duplicated_records: bool = attr.ib(init=False, default=False) _include_deleted_records: bool = attr.ib(init=False, default=False) + _cleanup_temporary_tables: bool = attr.ib(init=False, default=False) _number_of_recent_records: int = attr.ib(init=False, default=None) _number_of_records: int = attr.ib(init=False, default=None) _write_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None) @@ -282,6 +289,16 @@ def include_deleted_records(self): self._include_deleted_records = True return self + def cleanup_temporary_tables(self): + """Cleans up temporary tables when calling to_dataframe with a dataframe as the base + feature group. + + Returns: + This DatasetBuilder object. + """ + self._cleanup_temporary_tables = True + return self + def with_number_of_recent_records_by_record_identifier(self, number_of_recent_records: int): """Set number_of_recent_records field with provided input. 
@@ -384,10 +401,15 @@ def to_csv_file(self) -> Tuple[str, str]: ) ) query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE) - # TODO: cleanup temp table, need more clarification, keep it for now - return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get( + + res = query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get( "OutputLocation", None ), query_result.get("QueryExecution", {}).get("Query", None) + + if self._cleanup_temporary_tables is True: + self._drop_temp_table(temp_table_name) + + return res if isinstance(self._base, FeatureGroup): base_feature_group = construct_feature_group_to_be_merged( self._base, self._included_feature_names @@ -953,6 +975,21 @@ def _create_temp_table(self, temp_table_name: str, desired_s3_folder: str): ) self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + def _drop_temp_table(self, temp_table_name: str): + """Internal method to drop a temp Athena table for the base pandas.DataFrame. + + Args: + temp_table_name (str): The Athena table name of base pandas.DataFrame. + """ + query_string = ( + f"DROP TABLE `{_DEFAULT_DATABASE}.{temp_table_name}`" + ) + + try: + self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + except Exception as e: # pylint: disable=broad-except + logger.debug("The temporary table was unsuccessfully cleaned up. %s", e) + def _construct_athena_table_column_string(self, column: str) -> str: """Internal method for constructing string of Athena column. 
diff --git a/tests/unit/sagemaker/feature_store/test_dataset_builder.py b/tests/unit/sagemaker/feature_store/test_dataset_builder.py index 0e55b86bd0..936438c380 100644 --- a/tests/unit/sagemaker/feature_store/test_dataset_builder.py +++ b/tests/unit/sagemaker/feature_store/test_dataset_builder.py @@ -553,6 +553,18 @@ def test_create_temp_table(sagemaker_session_mock): kms_key=None, ) +def test_cleanup_temporary_tables(sagemaker_session_mock): + dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]}) + + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="file/to/path", + ) + + dataset_builder.cleanup_temporary_tables() + + assert dataset_builder._cleanup_temporary_tables is True @pytest.mark.parametrize( "column, expected",