Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions src/sagemaker/feature_store/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@
import os
from typing import Any, Dict, List, Tuple, Union

import logging

import attr
import pandas as pd

from sagemaker import Session, s3, utils
from sagemaker.feature_store.feature_group import FeatureDefinition, FeatureGroup, FeatureTypeEnum

logger = logging.getLogger(__name__)

_DEFAULT_CATALOG = "AwsDataCatalog"
_DEFAULT_DATABASE = "sagemaker_featurestore"
Expand Down Expand Up @@ -185,6 +188,9 @@ class DatasetBuilder:
when calling "to_dataframe" should include duplicated records (default: False).
_include_deleted_records (bool): A boolean representing whether the resulting
dataframe when calling "to_dataframe" should include deleted records (default: False).
_cleanup_temporary_tables (bool): A boolean representing whether temporary tables are
cleaned up after calling to_dataframe when a dataframe is the base. This requires
Glue:DeleteTable permissions in your execution role. (default: False)
_number_of_recent_records (int): An integer representing how many records will be
returned for each record identifier (default: 1).
_number_of_records (int): An integer representing the number of records that should be
Expand Down Expand Up @@ -215,6 +221,7 @@ class DatasetBuilder:
_point_in_time_accurate_join: bool = attr.ib(init=False, default=False)
_include_duplicated_records: bool = attr.ib(init=False, default=False)
_include_deleted_records: bool = attr.ib(init=False, default=False)
_cleanup_temporary_tables: bool = attr.ib(init=False, default=False)
_number_of_recent_records: int = attr.ib(init=False, default=None)
_number_of_records: int = attr.ib(init=False, default=None)
_write_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None)
Expand Down Expand Up @@ -282,6 +289,16 @@ def include_deleted_records(self):
self._include_deleted_records = True
return self

def cleanup_temporary_tables(self):
    """Enable removal of temporary Athena tables after dataset retrieval.

    When the base of this builder is a pandas.DataFrame, dataset retrieval
    stages it in a temporary table; with this flag enabled that table is
    dropped once the result has been produced.

    Returns:
        This DatasetBuilder object.
    """
    self._cleanup_temporary_tables = True
    return self

def with_number_of_recent_records_by_record_identifier(self, number_of_recent_records: int):
"""Set number_of_recent_records field with provided input.

Expand Down Expand Up @@ -384,10 +401,15 @@ def to_csv_file(self) -> Tuple[str, str]:
)
)
query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
# TODO: cleanup temp table, need more clarification, keep it for now
return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(

res = query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
"OutputLocation", None
), query_result.get("QueryExecution", {}).get("Query", None)

if self._cleanup_temporary_tables is True:
self._drop_temp_table(temp_table_name)

return res
if isinstance(self._base, FeatureGroup):
base_feature_group = construct_feature_group_to_be_merged(
self._base, self._included_feature_names
Expand Down Expand Up @@ -953,6 +975,22 @@ def _create_temp_table(self, temp_table_name: str, desired_s3_folder: str):
)
self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)

def _drop_temp_table(self, temp_table_name: str):
    """Internal method to drop the temp Athena table created for the base pandas.DataFrame.

    Cleanup is best-effort: a failed DROP (for example, a missing
    Glue:DeleteTable permission on the execution role) is logged and
    swallowed so that it never fails the caller's dataset retrieval.

    Args:
        temp_table_name (str): The Athena table name of the base pandas.DataFrame.
    """
    query_string = f"DROP TABLE `{_DEFAULT_DATABASE}.{temp_table_name}`"
    try:
        self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
    except Exception as e:  # pylint: disable=broad-except
        # Deliberately broad: cleanup failure must not propagate to the caller.
        logger.debug("The temporary table was unsuccessfully cleaned up. %s", e)

def _construct_athena_table_column_string(self, column: str) -> str:
"""Internal method for constructing string of Athena column.

Expand Down
12 changes: 12 additions & 0 deletions tests/unit/sagemaker/feature_store/test_dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,18 @@ def test_create_temp_table(sagemaker_session_mock):
kms_key=None,
)

def test_cleanup_temporary_tables(sagemaker_session_mock):
    """cleanup_temporary_tables sets the flag and returns the builder for chaining."""
    dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]})

    dataset_builder = DatasetBuilder(
        sagemaker_session=sagemaker_session_mock,
        base=dataframe,
        output_path="file/to/path",
    )

    result = dataset_builder.cleanup_temporary_tables()

    # The setter is fluent; verify both the flag and the chainable return value.
    assert result is dataset_builder
    assert dataset_builder._cleanup_temporary_tables is True

@pytest.mark.parametrize(
"column, expected",
Expand Down