Skip to content
61 changes: 61 additions & 0 deletions src/sagemaker/feature_store/feature_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
OfflineStoreConfig,
DataCatalogConfig,
FeatureValue,
FeatureParameter,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -537,6 +538,66 @@ def describe(self, next_token: str = None) -> Dict[str, Any]:
feature_group_name=self.name, next_token=next_token
)

def update(self, feature_additions: Sequence[FeatureDefinition]) -> Dict[str, Any]:
"""Update a FeatureGroup and add new features from the given feature definitions.

Args:
feature_additions (Sequence[Dict[str, str]): list of feature definitions to be updated.

Returns:
Response dict from service.
"""

return self.sagemaker_session.update_feature_group(
feature_group_name=self.name,
feature_additions=[
feature_addition.to_dict() for feature_addition in feature_additions
],
)

def update_feature_metadata(
self,
feature_name: str,
description: str = None,
parameter_additions: Sequence[FeatureParameter] = None,
parameter_removals: Sequence[str] = None,
) -> Dict[str, Any]:
"""Update a feature metadata and add/remove metadata.

Args:
feature_name (str): name of the feature to update.
description (str): description of the feature to update.
parameter_additions (Sequence[Dict[str, str]): list of feature parameter to be added.
parameter_removals (Sequence[str]): list of feature parameter key to be removed.

Returns:
Response dict from service.
"""
return self.sagemaker_session.update_feature_metadata(
feature_group_name=self.name,
feature_name=feature_name,
description=description,
parameter_additions=[
parameter_addition.to_dict() for parameter_addition in parameter_additions
]
if parameter_additions is not None
else [],
parameter_removals=parameter_removals if parameter_removals is not None else [],
)

def describe_feature_metadata(self, feature_name: str) -> Dict[str, Any]:
"""Describe feature metadata by feature name.

Args:
feature_name (str): name of the feature.
Returns:
Response dict from service.
"""

return self.sagemaker_session.describe_feature_metadata(
feature_group_name=self.name, feature_name=feature_name
)

def load_feature_definitions(
self,
data_frame: DataFrame,
Expand Down
24 changes: 24 additions & 0 deletions src/sagemaker/feature_store/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,27 @@ def to_dict(self) -> Dict[str, Any]:
FeatureName=self.feature_name,
ValueAsString=self.value_as_string,
)


@attr.s
class FeatureParameter(Config):
"""FeatureParameter for FeatureStore.

Attributes:
key (str): key of the parameter.
value (str): value of the parameter.
"""

key: str = attr.ib(default=None)
value: str = attr.ib(default=None)

def to_dict(self) -> Dict[str, Any]:
"""Construct a dictionary based on the attributes provided.

Returns:
dict represents the attributes.
"""
return Config.construct_dict(
Key=self.key,
Value=self.value,
)
68 changes: 67 additions & 1 deletion src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4076,7 +4076,7 @@ def describe_feature_group(
"""Describe a FeatureGroup by name in FeatureStore service.

Args:
feature_group_name (str): name of the FeatureGroup to descibe.
feature_group_name (str): name of the FeatureGroup to describe.
next_token (str): next_token to get next page of features.
Returns:
Response dict from service.
Expand All @@ -4086,6 +4086,72 @@ def describe_feature_group(
update_args(kwargs, NextToken=next_token)
return self.sagemaker_client.describe_feature_group(**kwargs)

def update_feature_group(
self, feature_group_name: str, feature_additions: Sequence[Dict[str, str]]
) -> Dict[str, Any]:
"""Update a FeatureGroup and add new features from the given feature definitions.

Args:
feature_group_name (str): name of the FeatureGroup to update.
feature_additions (Sequence[Dict[str, str]): list of feature definitions to be updated.
Returns:
Response dict from service.
"""

return self.sagemaker_client.update_feature_group(
FeatureGroupName=feature_group_name, FeatureAdditions=feature_additions
)

def update_feature_metadata(
self,
feature_group_name: str,
feature_name: str,
description: str = None,
parameter_additions: Sequence[Dict[str, str]] = None,
parameter_removals: Sequence[str] = None,
) -> Dict[str, Any]:
"""Update a feature metadata and add/remove metadata.

Args:
feature_group_name (str): name of the FeatureGroup to update.
feature_name (str): name of the feature to update.
description (str): description of the feature to update.
parameter_additions (Sequence[Dict[str, str]): list of feature parameter to be added.
parameter_removals (Sequence[Dict[str, str]): list of feature parameter to be removed.
Returns:
Response dict from service.
"""

request = {
"FeatureGroupName": feature_group_name,
"FeatureName": feature_name,
}

if description is not None:
request["Description"] = description
if parameter_additions is not None:
request["ParameterAdditions"] = parameter_additions
if parameter_removals is not None:
request["ParameterRemovals"] = parameter_removals

return self.sagemaker_client.update_feature_metadata(**request)

def describe_feature_metadata(
self, feature_group_name: str, feature_name: str
) -> Dict[str, Any]:
"""Describe feature metadata by feature name in FeatureStore service.

Args:
feature_group_name (str): name of the FeatureGroup.
feature_name (str): name of the feature.
Returns:
Response dict from service.
"""

return self.sagemaker_client.describe_feature_metadata(
FeatureGroupName=feature_group_name, FeatureName=feature_name
)

def put_record(
self,
feature_group_name: str,
Expand Down
92 changes: 91 additions & 1 deletion tests/integ/test_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
import pytest
from pandas import DataFrame

from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition
from sagemaker.feature_store.feature_group import FeatureGroup
from sagemaker.feature_store.inputs import FeatureValue
from sagemaker.feature_store.inputs import FeatureValue, FeatureParameter
from sagemaker.session import get_execution_role, Session
from tests.integ.timeout import timeout

Expand Down Expand Up @@ -237,6 +238,83 @@ def test_create_feature_store(
assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")


def test_update_feature_group(
feature_store_session,
role,
feature_group_name,
offline_store_s3_uri,
pandas_data_frame,
):
feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
feature_group.load_feature_definitions(data_frame=pandas_data_frame)

with cleanup_feature_group(feature_group):
feature_group.create(
s3_uri=offline_store_s3_uri,
record_identifier_name="feature1",
event_time_feature_name="feature3",
role_arn=role,
enable_online_store=True,
)
_wait_for_feature_group_create(feature_group)

new_feature_name = "new_feature"
new_features = [FractionalFeatureDefinition(feature_name=new_feature_name)]
feature_group.update(new_features)
_wait_for_feature_group_update(feature_group)
feature_definitions = feature_group.describe().get("FeatureDefinitions")
assert any([True for elem in feature_definitions if new_feature_name in elem.values()])


def test_feature_metadata(
feature_store_session,
role,
feature_group_name,
offline_store_s3_uri,
pandas_data_frame,
):
feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
feature_group.load_feature_definitions(data_frame=pandas_data_frame)

with cleanup_feature_group(feature_group):
feature_group.create(
s3_uri=offline_store_s3_uri,
record_identifier_name="feature1",
event_time_feature_name="feature3",
role_arn=role,
enable_online_store=True,
)
_wait_for_feature_group_create(feature_group)

parameter_additions = [
FeatureParameter(key="key1", value="value1"),
FeatureParameter(key="key2", value="value2"),
]
description = "test description"
feature_name = "feature1"
feature_group.update_feature_metadata(
feature_name=feature_name,
description=description,
parameter_additions=parameter_additions,
)
describe_feature_metadata = feature_group.describe_feature_metadata(
feature_name=feature_name
)
print(describe_feature_metadata)
assert description == describe_feature_metadata.get("Description")
assert 2 == len(describe_feature_metadata.get("Parameters"))

parameter_removals = ["key1"]
feature_group.update_feature_metadata(
feature_name=feature_name, parameter_removals=parameter_removals
)
describe_feature_metadata = feature_group.describe_feature_metadata(
feature_name=feature_name
)
assert description == describe_feature_metadata.get("Description")
assert 1 == len(describe_feature_metadata.get("Parameters"))


def test_ingest_without_string_feature(
feature_store_session,
role,
Expand Down Expand Up @@ -304,6 +382,18 @@ def _wait_for_feature_group_create(feature_group: FeatureGroup):
print(f"FeatureGroup {feature_group.name} successfully created.")


def _wait_for_feature_group_update(feature_group: FeatureGroup):
status = feature_group.describe().get("LastUpdateStatus").get("Status")
while status == "InProgress":
print("Waiting for Feature Group Update")
time.sleep(5)
status = feature_group.describe().get("LastUpdateStatus").get("Status")
if status != "Successful":
print(feature_group.describe())
raise RuntimeError(f"Failed to update feature group {feature_group.name}")
print(f"FeatureGroup {feature_group.name} successfully updated.")


@contextmanager
def cleanup_feature_group(feature_group: FeatureGroup):
try:
Expand Down
47 changes: 47 additions & 0 deletions tests/unit/sagemaker/feature_store/test_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
AthenaQuery,
IngestionError,
)
from sagemaker.feature_store.inputs import FeatureParameter


class PicklableMock(Mock):
Expand Down Expand Up @@ -154,6 +155,52 @@ def test_feature_store_describe(sagemaker_session_mock):
)


def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions):
feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
feature_group.update(feature_group_dummy_definitions)
sagemaker_session_mock.update_feature_group.assert_called_with(
feature_group_name="MyFeatureGroup",
feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions],
)


def test_feature_metadata_update(sagemaker_session_mock):
feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)

parameter_additions = [FeatureParameter(key="key1", value="value1")]
parameter_removals = ["key2"]

feature_group.update_feature_metadata(
feature_name="Feature1",
description="TestDescription",
parameter_additions=parameter_additions,
parameter_removals=parameter_removals,
)
sagemaker_session_mock.update_feature_metadata.assert_called_with(
feature_group_name="MyFeatureGroup",
feature_name="Feature1",
description="TestDescription",
parameter_additions=[pa.to_dict() for pa in parameter_additions],
parameter_removals=parameter_removals,
)
feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription")
sagemaker_session_mock.update_feature_metadata.assert_called_with(
feature_group_name="MyFeatureGroup",
feature_name="Feature1",
description="TestDescription",
parameter_additions=[],
parameter_removals=[],
)


def test_feature_metadata_describe(sagemaker_session_mock):
feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
feature_group.describe_feature_metadata(feature_name="Feature1")
sagemaker_session_mock.describe_feature_metadata.assert_called_with(
feature_group_name="MyFeatureGroup", feature_name="Feature1"
)


def test_put_record(sagemaker_session_mock):
feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
feature_group.put_record(record=[])
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/sagemaker/feature_store/test_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
S3StorageConfig,
DataCatalogConfig,
OfflineStoreConfig,
FeatureParameter,
)


Expand Down Expand Up @@ -83,3 +84,8 @@ def test_offline_data_store_config():
"DisableGlueTableCreation": False,
}
)


def test_feature_metadata():
config = FeatureParameter(key="key", value="value")
assert ordered(config.to_dict()) == ordered({"Key": "key", "Value": "value"})
Loading