Skip to content

Commit

Permalink
fix: add kwargs to create_model for 1p to work with kms (#1081)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan authored and icywang86rui committed Oct 10, 2019
1 parent d12071f commit 57a2384
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 7 deletions.
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/factorization_machines.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def __init__(
self.factors_init_sigma = factors_init_sigma
self.factors_init_value = factors_init_value

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.FactorizationMachinesModel`
referencing the latest s3 model data produced by this Estimator.
Expand All @@ -244,12 +244,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the FactorizationMachinesModel constructor.
"""
return FactorizationMachinesModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)


Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/ipinsights.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(
self.shuffled_negative_sampling_rate = shuffled_negative_sampling_rate
self.weight_decay = weight_decay

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Create a model for the latest s3 model produced by this estimator.
Args:
Expand All @@ -140,6 +140,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the IPInsightsModel constructor.
Returns:
:class:`~sagemaker.amazon.IPInsightsModel`: references the latest s3 model
data produced by this estimator.
Expand All @@ -149,6 +150,7 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def __init__(
'"dimension_reduction_target" is required when "dimension_reduction_type" is set.'
)

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.KNNModel` referencing the latest
s3 model data produced by this Estimator.
Expand All @@ -154,12 +154,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the KNNModel constructor.
"""
return KNNModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/ntm.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def __init__(
self.weight_decay = weight_decay
self.learning_rate = learning_rate

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.NTMModel` referencing the latest
s3 model data produced by this Estimator.
Expand All @@ -164,12 +164,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the NTMModel constructor.
"""
return NTMModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training( # pylint: disable=signature-differs
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/object2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def __init__(
self.enc0_freeze_pretrained_embedding = enc0_freeze_pretrained_embedding
self.enc1_freeze_pretrained_embedding = enc1_freeze_pretrained_embedding

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.Object2VecModel` referencing the
latest s3 model data produced by this Estimator.
Expand All @@ -304,12 +304,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the Object2VecModel constructor.
"""
return Object2VecModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(
self.subtract_mean = subtract_mean
self.extra_components = extra_components

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.pca.PCAModel` referencing the
latest s3 model data produced by this Estimator.
Expand All @@ -129,12 +129,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the PCAModel constructor.
"""
return PCAModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/randomcutforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
self.num_trees = num_trees
self.eval_metrics = eval_metrics

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.RandomCutForestModel` referencing
the latest s3 model data produced by this Estimator.
Expand All @@ -122,12 +122,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the RandomCutForestModel constructor.
"""
return RandomCutForestModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
Expand Down

0 comments on commit 57a2384

Please sign in to comment.