Skip to content

Commit

Permalink
fix: prevent multiple values error in sklearn.transformer() (#978)
Browse files Browse the repository at this point in the history
  • Loading branch information
jesterhazy committed Aug 13, 2019
1 parent b7a2b9c commit 987bbe6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,12 @@ def create_model(
object. See :func:`~sagemaker.sklearn.model.SKLearnModel` for full details.
"""
role = role or self.role

# remove unwanted entry_point kwarg
if "entry_point" in kwargs:
logger.debug("removing unused entry_point argument: %s", str(kwargs["entry_point"]))
kwargs = {k: v for k, v in kwargs.items() if k != "entry_point"}

return SKLearnModel(
self.model_data,
role,
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,26 @@ def test_sklearn(strftime, sagemaker_session, sklearn_version):
assert isinstance(predictor, SKLearnPredictor)


def test_transform_multiple_values_for_entry_point_issue(sagemaker_session, sklearn_version):
# https://github.com/aws/sagemaker-python-sdk/issues/974
sklearn = SKLearn(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
train_instance_type=INSTANCE_TYPE,
py_version=PYTHON_VERSION,
framework_version=sklearn_version,
)

inputs = "s3://mybucket/train"

sklearn.fit(inputs=inputs)

transformer = sklearn.transformer(instance_count=1, instance_type="ml.m4.xlarge")
# if we got here, we didn't get a "multiple values" error
assert transformer is not None


def test_fail_distributed_training(sagemaker_session, sklearn_version):
with pytest.raises(AttributeError) as error:
SKLearn(
Expand Down

0 comments on commit 987bbe6

Please sign in to comment.