Skip to content

Commit

Permalink
Adding Object2Vec support to SageMaker Python SDK (#467)
Browse files Browse the repository at this point in the history
  • Loading branch information
pnpnpn authored and laurenyu committed Nov 8, 2018
1 parent 5201c60 commit 9285b6a
Show file tree
Hide file tree
Showing 10 changed files with 1,580 additions and 5 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
CHANGELOG
=========

1.14.1
======

* feature: Estimators: add support for Amazon Object2Vec algorithm

1.14.0
======

Expand Down
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ Amazon SageMaker provides several built-in machine learning algorithms that you
The full list of algorithms is available at: https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html
The SageMaker Python SDK includes estimator wrappers for the AWS K-means, Principal Components Analysis (PCA), Linear Learner, Factorization Machines,
Latent Dirichlet Allocation (LDA), Neural Topic Model (NTM) Random Cut Forest and k-nearest neighbors (k-NN) algorithms.
Latent Dirichlet Allocation (LDA), Neural Topic Model (NTM), Random Cut Forest, k-nearest neighbors (k-NN), and Object2Vec algorithms.
For more information, see `AWS SageMaker Estimators and Models`_.
Expand Down
3 changes: 2 additions & 1 deletion src/sagemaker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sagemaker.amazon.randomcutforest import (RandomCutForest, RandomCutForestModel, # noqa: F401
RandomCutForestPredictor)
from sagemaker.amazon.knn import KNN, KNNModel, KNNPredictor # noqa: F401
from sagemaker.amazon.object2vec import Object2Vec, Object2VecModel # noqa: F401

from sagemaker.analytics import TrainingJobAnalytics, HyperparameterTuningJobAnalytics # noqa: F401
from sagemaker.local.local_session import LocalSession # noqa: F401
Expand All @@ -35,4 +36,4 @@
from sagemaker.session import s3_input # noqa: F401
from sagemaker.session import get_execution_role # noqa: F401

__version__ = '1.14.0'
__version__ = '1.14.1'
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Amazon SageMaker provides several built-in machine learning algorithms that you

The full list of algorithms is available on the AWS website: https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html

SageMaker Python SDK includes Estimator wrappers for the AWS K-means, Principal Components Analysis(PCA), Linear Learner, Factorization Machines, Latent Dirichlet Allocation(LDA), Neural Topic Model(NTM), Random Cut Forest algorithms and k-nearest neighbors (k-NN).
SageMaker Python SDK includes Estimator wrappers for the AWS K-means, Principal Components Analysis(PCA), Linear Learner, Factorization Machines, Latent Dirichlet Allocation(LDA), Neural Topic Model(NTM), Random Cut Forest algorithms, k-nearest neighbors (k-NN) and Object2Vec.

Definition and usage
~~~~~~~~~~~~~~~~~~~~
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def registry(region_name, algorithm=None):
https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/amazon
"""
if algorithm in [None, "pca", "kmeans", "linear-learner", "factorization-machines", "ntm",
"randomcutforest", "knn"]:
"randomcutforest", "knn", "object2vec"]:
account_id = {
"us-east-1": "382416733822",
"us-east-2": "404615174143",
Expand Down
247 changes: 247 additions & 0 deletions src/sagemaker/amazon/object2vec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
from sagemaker.amazon.validation import ge, le, isin
from sagemaker.predictor import RealTimePredictor
from sagemaker.model import Model
from sagemaker.session import Session
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT


class Object2Vec(AmazonAlgorithmEstimatorBase):

repo_name = 'object2vec'
repo_version = 1
MINI_BATCH_SIZE = 32

enc_dim = hp('enc_dim', (ge(4), le(10000)),
'An integer in [4, 10000]', int)
mini_batch_size = hp('mini_batch_size', (ge(1), le(10000)),
'An integer in [1, 10000]', int)
epochs = hp('epochs', (ge(1), le(100)),
'An integer in [1, 100]', int)
early_stopping_patience = hp('early_stopping_patience', (ge(1), le(5)),
'An integer in [1, 5]', int)
early_stopping_tolerance = hp('early_stopping_tolerance', (ge(1e-06), le(0.1)),
'A float in [1e-06, 0.1]', float)
dropout = hp('dropout', (ge(0.0), le(1.0)),
'A float in [0.0, 1.0]', float)
weight_decay = hp('weight_decay', (ge(0.0), le(10000.0)),
'A float in [0.0, 10000.0]', float)
bucket_width = hp('bucket_width', (ge(0), le(100)),
'An integer in [0, 100]', int)
num_classes = hp('num_classes', (ge(2), le(30)),
'An integer in [2, 30]', int)
mlp_layers = hp('mlp_layers', (ge(1), le(10)),
'An integer in [1, 10]', int)
mlp_dim = hp('mlp_dim', (ge(2), le(10000)),
'An integer in [2, 10000]', int)
mlp_activation = hp('mlp_activation', isin("tanh", "relu", "linear"),
'One of "tanh", "relu", "linear"', str)
output_layer = hp('output_layer', isin("softmax", "mean_squared_error"),
'One of "softmax", "mean_squared_error"', str)
optimizer = hp('optimizer', isin("adagrad", "adam", "rmsprop", "sgd", "adadelta"),
'One of "adagrad", "adam", "rmsprop", "sgd", "adadelta"', str)
learning_rate = hp('learning_rate', (ge(1e-06), le(1.0)),
'A float in [1e-06, 1.0]', float)
enc0_network = hp('enc0_network', isin("hcnn", "bilstm", "pooled_embedding"),
'One of "hcnn", "bilstm", "pooled_embedding"', str)
enc1_network = hp('enc1_network', isin("hcnn", "bilstm", "pooled_embedding", "enc0"),
'One of "hcnn", "bilstm", "pooled_embedding", "enc0"', str)
enc0_cnn_filter_width = hp('enc0_cnn_filter_width', (ge(1), le(9)),
'An integer in [1, 9]', int)
enc1_cnn_filter_width = hp('enc1_cnn_filter_width', (ge(1), le(9)),
'An integer in [1, 9]', int)
enc0_max_seq_len = hp('enc0_max_seq_len', (ge(1), le(5000)),
'An integer in [1, 5000]', int)
enc1_max_seq_len = hp('enc1_max_seq_len', (ge(1), le(5000)),
'An integer in [1, 5000]', int)
enc0_token_embedding_dim = hp('enc0_token_embedding_dim', (ge(2), le(1000)),
'An integer in [2, 1000]', int)
enc1_token_embedding_dim = hp('enc1_token_embedding_dim', (ge(2), le(1000)),
'An integer in [2, 1000]', int)
enc0_vocab_size = hp('enc0_vocab_size', (ge(2), le(3000000)),
'An integer in [2, 3000000]', int)
enc1_vocab_size = hp('enc1_vocab_size', (ge(2), le(3000000)),
'An integer in [2, 3000000]', int)
enc0_layers = hp('enc0_layers', (ge(1), le(4)),
'An integer in [1, 4]', int)
enc1_layers = hp('enc1_layers', (ge(1), le(4)),
'An integer in [1, 4]', int)
enc0_freeze_pretrained_embedding = hp('enc0_freeze_pretrained_embedding', (),
'Either True or False', bool)
enc1_freeze_pretrained_embedding = hp('enc1_freeze_pretrained_embedding', (),
'Either True or False', bool)

def __init__(self, role, train_instance_count, train_instance_type,
epochs,
enc0_max_seq_len,
enc0_vocab_size,
enc_dim=None,
mini_batch_size=None,
early_stopping_patience=None,
early_stopping_tolerance=None,
dropout=None,
weight_decay=None,
bucket_width=None,
num_classes=None,
mlp_layers=None,
mlp_dim=None,
mlp_activation=None,
output_layer=None,
optimizer=None,
learning_rate=None,
enc0_network=None,
enc1_network=None,
enc0_cnn_filter_width=None,
enc1_cnn_filter_width=None,
enc1_max_seq_len=None,
enc0_token_embedding_dim=None,
enc1_token_embedding_dim=None,
enc1_vocab_size=None,
enc0_layers=None,
enc1_layers=None,
enc0_freeze_pretrained_embedding=None,
enc1_freeze_pretrained_embedding=None,
**kwargs):
"""Object2Vec is :class:`Estimator` used for anomaly detection.
This Estimator may be fit via calls to
:meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit`.
There is an utility :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.record_set` that
can be used to upload data to S3 and creates :class:`~sagemaker.amazon.amazon_estimator.RecordSet` to be passed
to the `fit` call.
After this Estimator is fit, model data is stored in S3. The model may be deployed to an Amazon SageMaker
Endpoint by invoking :meth:`~sagemaker.amazon.estimator.EstimatorBase.deploy`. As well as deploying an
Endpoint, deploy returns a :class:`~sagemaker.amazon.RealTimePredictor` object that can be used
for inference calls using the trained model hosted in the SageMaker Endpoint.
Object2Vec Estimators can be configured by setting hyperparameters. The available hyperparameters for
Object2Vec are documented below.
For further information on the AWS Object2Vec algorithm,
please consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/object2vec.html
Args:
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and
APIs that create Amazon SageMaker endpoints use this role to access
training data and model artifacts. After the endpoint is created,
the inference code might use the IAM role, if accessing AWS resource.
train_instance_count (int): Number of Amazon EC2 instances to use for training.
train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
epochs(int): Total number of epochs for SGD training
enc0_max_seq_len(int): Maximum sequence length
enc0_vocab_size(int): Vocabulary size of tokens
enc_dim(int): Optional. Dimension of the output of the embedding layer
mini_batch_size(int): Optional. mini batch size for SGD training
early_stopping_patience(int): Optional. The allowed number of consecutive epochs without improvement
before early stopping is applied
early_stopping_tolerance(float): Optional. The value used to determine whether the algorithm has made
improvement between two consecutive epochs for early stopping
dropout(float): Optional. Dropout probability on network layers
weight_decay(float): Optional. Weight decay parameter during optimization
bucket_width(int): Optional. The allowed difference between data sequence length when bucketing is enabled
num_classes(int): Optional. Number of classes for classification training (ignored for regression problems)
mlp_layers(int): Optional. Number of MLP layers in the network
mlp_dim(int): Optional. Dimension of the output of MLP layer
mlp_activation(str): Optional. Type of activation function for the MLP layer
output_layer(str): Optional. Type of output layer
optimizer(str): Optional. Type of optimizer for training
learning_rate(float): Optional. Learning rate for SGD training
enc0_network(str): Optional. Network model of encoder "enc0"
enc1_network(str): Optional. Network model of encoder "enc1"
enc0_cnn_filter_width(int): Optional. CNN filter width
enc1_cnn_filter_width(int): Optional. CNN filter width
enc1_max_seq_len(int): Optional. Maximum sequence length
enc0_token_embedding_dim(int): Optional. Output dimension of token embedding layer
enc1_token_embedding_dim(int): Optional. Output dimension of token embedding layer
enc1_vocab_size(int): Optional. Vocabulary size of tokens
enc0_layers(int): Optional. Number of layers in encoder
enc1_layers(int): Optional. Number of layers in encoder
enc0_freeze_pretrained_embedding(bool): Optional. Freeze pretrained embedding weights
enc1_freeze_pretrained_embedding(bool): Optional. Freeze pretrained embedding weights
**kwargs: base class keyword argument values.
"""

super(Object2Vec, self).__init__(role, train_instance_count, train_instance_type, **kwargs)

self.enc_dim = enc_dim
self.mini_batch_size = mini_batch_size
self.epochs = epochs
self.early_stopping_patience = early_stopping_patience
self.early_stopping_tolerance = early_stopping_tolerance
self.dropout = dropout
self.weight_decay = weight_decay
self.bucket_width = bucket_width
self.num_classes = num_classes
self.mlp_layers = mlp_layers
self.mlp_dim = mlp_dim
self.mlp_activation = mlp_activation
self.output_layer = output_layer
self.optimizer = optimizer
self.learning_rate = learning_rate
self.enc0_network = enc0_network
self.enc1_network = enc1_network
self.enc0_cnn_filter_width = enc0_cnn_filter_width
self.enc1_cnn_filter_width = enc1_cnn_filter_width
self.enc0_max_seq_len = enc0_max_seq_len
self.enc1_max_seq_len = enc1_max_seq_len
self.enc0_token_embedding_dim = enc0_token_embedding_dim
self.enc1_token_embedding_dim = enc1_token_embedding_dim
self.enc0_vocab_size = enc0_vocab_size
self.enc1_vocab_size = enc1_vocab_size
self.enc0_layers = enc0_layers
self.enc1_layers = enc1_layers
self.enc0_freeze_pretrained_embedding = enc0_freeze_pretrained_embedding
self.enc1_freeze_pretrained_embedding = enc1_freeze_pretrained_embedding

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
"""Return a :class:`~sagemaker.amazon.Object2VecModel` referencing the latest
s3 model data produced by this Estimator.
Args:
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on the model.
Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
"""
return Object2VecModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override))

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
if mini_batch_size is None:
mini_batch_size = self.MINI_BATCH_SIZE

super(Object2Vec, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name)


class Object2VecModel(Model):
"""Reference Object2Vec s3 model data. Calling :meth:`~sagemaker.model.Model.deploy` creates an
Endpoint and returns a Predictor that calculates anomaly scores for datapoints."""

def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
sagemaker_session = sagemaker_session or Session()
repo = '{}:{}'.format(Object2Vec.repo_name, Object2Vec.repo_version)
image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name,
Object2Vec.repo_name), repo)
super(Object2VecModel, self).__init__(model_data, image, role,
predictor_cls=RealTimePredictor,
sagemaker_session=sagemaker_session,
**kwargs)
3 changes: 2 additions & 1 deletion src/sagemaker/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
'linear-learner': 'LinearLearner',
'ntm': 'NTM',
'randomcutforest': 'RandomCutForest',
'knn': 'KNN'
'knn': 'KNN',
'object2vec': 'Object2Vec',
}


Expand Down

0 comments on commit 9285b6a

Please sign in to comment.