diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ccf88d1de9..f7810b3187 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,15 @@ CHANGELOG ========= +1.0.4 +===== + +* feature: Estimators: add support for Amazon Neural Topic Model (NTM) algorithm +* feature: Documentation: fix description of an argument of sagemaker.session.train +* feature: Documentation: add FM and LDA to the documentation +* feature: Estimators: add support for async fit +* bug-fix: Estimators: fix estimator role expansion + 1.0.3 ===== diff --git a/README.rst b/README.rst index f60aafa109..63dd2f102c 100644 --- a/README.rst +++ b/README.rst @@ -39,7 +39,7 @@ You can install from source by cloning this repository and issuing a pip install git clone https://github.com/aws/sagemaker-python-sdk.git python setup.py sdist - pip install dist/sagemaker-1.0.3.tar.gz + pip install dist/sagemaker-1.0.4.tar.gz Supported Python versions ~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -1447,7 +1447,7 @@ Amazon SageMaker provides several built-in machine learning algorithms that you The full list of algorithms is available on the AWS website: https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html -SageMaker Python SDK includes Estimator wrappers for the AWS K-means, Principal Components Analysis, Linear Learner, Factorization Machines and LDA algorithms. +SageMaker Python SDK includes Estimator wrappers for the AWS K-means, Principal Components Analysis (PCA), Linear Learner, Factorization Machines, Latent Dirichlet Allocation (LDA) and Neural Topic Model (NTM) algorithms. Definition and usage ~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/conf.py b/doc/conf.py index 3675148d67..be63a6bd20 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -18,7 +18,7 @@ def __getattr__(cls, name): 'tensorflow.python.framework', 'tensorflow_serving', 'tensorflow_serving.apis'] sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) -version = '1.0.3' +version = '1.0.4' project = u'sagemaker' # Add any Sphinx extension module names here, as strings. They can be extensions diff --git a/doc/index.rst b/doc/index.rst index 1a840f6257..9e97ecba83 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -49,3 +49,4 @@ Amazon provides implementations of some common machine learning algortithms opti sagemaker.amazon.amazon_estimator factorization_machines lda + ntm diff --git a/doc/ntm.rst b/doc/ntm.rst new file mode 100644 index 0000000000..628cfd7de8 --- /dev/null +++ b/doc/ntm.rst @@ -0,0 +1,23 @@ +NTM +-------------------- + +The Amazon SageMaker NTM algorithm. + +.. autoclass:: sagemaker.NTM + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + :exclude-members: image, num_topics, encoder_layers, epochs, encoder_layers_activation, optimizer, tolerance, + num_patience_epochs, batch_norm, rescale_gradient, clip_gradient, weight_decay, learning_rate + + +.. autoclass:: sagemaker.NTMModel + :members: + :undoc-members: + :show-inheritance: + +.. 
autoclass:: sagemaker.NTMPredictor + :members: + :undoc-members: + :show-inheritance: diff --git a/setup.py b/setup.py index 1edd9dd2dc..80c9fce712 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ def read(fname): setup(name="sagemaker", - version="1.0.3", + version="1.0.4", description="Open source library for training and deploying models on Amazon SageMaker.", packages=find_packages('src'), package_dir={'': 'src'}, diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index b293cf2f5c..93a62c2a72 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -19,6 +19,7 @@ from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerModel, LinearLearnerPredictor from sagemaker.amazon.factorization_machines import FactorizationMachines, FactorizationMachinesModel from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor +from sagemaker.amazon.ntm import NTM, NTMModel, NTMPredictor from sagemaker.model import Model from sagemaker.predictor import RealTimePredictor @@ -33,5 +34,5 @@ LinearLearnerModel, LinearLearnerPredictor, LDA, LDAModel, LDAPredictor, FactorizationMachines, FactorizationMachinesModel, FactorizationMachinesPredictor, - Model, RealTimePredictor, Session, + Model, NTM, NTMModel, NTMPredictor, RealTimePredictor, Session, container_def, s3_input, production_variant, get_execution_role] diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 27538edf59..22022c65f3 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -228,7 +228,7 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels= def registry(region_name, algorithm=None): """Return docker registry for the given AWS region""" - if algorithm in [None, "pca", "kmeans", "linear-learner", "factorization-machines"]: + if algorithm in [None, "pca", "kmeans", "linear-learner", "factorization-machines", "ntm"]: account_id = { "us-east-1": "382416733822", "us-east-2": "404615174143", diff --git a/src/sagemaker/amazon/ntm.py b/src/sagemaker/amazon/ntm.py new file mode 100644 index 0000000000..21f0c8f1aa --- /dev/null +++ b/src/sagemaker/amazon/ntm.py @@ -0,0 +1,146 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry +from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer +from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa +from sagemaker.amazon.validation import ge, le, isin +from sagemaker.predictor import RealTimePredictor +from sagemaker.model import Model +from sagemaker.session import Session + + +class NTM(AmazonAlgorithmEstimatorBase): + + repo_name = 'ntm' + repo_version = 1 + + num_topics = hp('num_topics', (ge(2), le(1000)), 'An integer in [2, 1000]', int) + encoder_layers = hp(name='encoder_layers', validation_message='A comma separated list of ' 'positive integers', data_type=list) + epochs = hp('epochs', (ge(1), le(100)), 'An integer in [1, 100]', int) + encoder_layers_activation = hp('encoder_layers_activation', isin('sigmoid', 'tanh', 'relu'), 'One of "sigmoid", "tanh" or "relu"', str) + optimizer = hp('optimizer', isin('adagrad', 'adam', 'rmsprop', 'sgd', 'adadelta'), 'One of "adagrad", "adam", "rmsprop", "sgd" or "adadelta"', str) + tolerance = hp('tolerance', (ge(1e-6), le(0.1)), 'A float in [1e-6, 0.1]', float) + num_patience_epochs = hp('num_patience_epochs', (ge(1), le(10)), 'An integer in [1, 10]', int) + batch_norm = hp(name='batch_norm', validation_message='Value must be a boolean', data_type=bool) + rescale_gradient = hp('rescale_gradient', (ge(1e-3), le(1.0)), 'A float in [1e-3, 1.0]', float) + clip_gradient = hp('clip_gradient', ge(1e-3), 'A float greater than or equal to 1e-3', float) + weight_decay = hp('weight_decay', (ge(0.0), le(1.0)), 'A float in [0.0, 1.0]', float) + learning_rate = hp('learning_rate', (ge(1e-6), le(1.0)), 'A float in [1e-6, 1.0]', float) + + def __init__(self, role, train_instance_count, train_instance_type, num_topics, + encoder_layers=None, epochs=None, encoder_layers_activation=None, optimizer=None, tolerance=None, + num_patience_epochs=None, batch_norm=None, rescale_gradient=None, clip_gradient=None, + weight_decay=None, learning_rate=None, **kwargs): + """Neural Topic Model (NTM) is an :class:`Estimator` used for unsupervised learning. + + This Estimator may be fit via calls to + :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit`. It requires Amazon + :class:`~sagemaker.amazon.record_pb2.Record` protobuf serialized data to be stored in S3. + There is a utility :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.record_set` that + can be used to upload data to S3 and create a :class:`~sagemaker.amazon.amazon_estimator.RecordSet` to be passed + to the `fit` call. + + To learn more about the Amazon protobuf Record class and how to prepare bulk data in this format, please + consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html + + After this Estimator is fit, model data is stored in S3. The model may be deployed to an Amazon SageMaker + Endpoint by invoking :meth:`~sagemaker.amazon.estimator.EstimatorBase.deploy`. As well as deploying an Endpoint, + deploy returns a :class:`~sagemaker.amazon.ntm.NTMPredictor` object that can be used + for inference calls using the trained model hosted in the SageMaker Endpoint. + + NTM Estimators can be configured by setting hyperparameters. The available hyperparameters for + NTM are documented below.
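+
+        As a minimal usage sketch (the role name, instance type and ``train_array``
+        below are illustrative placeholders, not defaults), one might train and
+        deploy with::
+
+            ntm = NTM(role='SageMakerRole', train_instance_count=1,
+                      train_instance_type='ml.c4.xlarge', num_topics=10)
+            records = ntm.record_set(train_array)  # uploads protobuf Records to S3
+            ntm.fit(records)
+            predictor = ntm.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge')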
+ + For further information on the AWS NTM algorithm, + please consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/ntm.html + + Args: + role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and + APIs that create Amazon SageMaker endpoints use this role to access + training data and model artifacts. After the endpoint is created, + the inference code might use the IAM role, if accessing AWS resources. + train_instance_count (int): Number of Amazon EC2 instances to use for training. + train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'. + num_topics (int): Required. The number of topics for NTM to find within the data. + encoder_layers (list): Optional. Represents the number of layers in the encoder and the output size of + each layer. + epochs (int): Optional. Maximum number of passes over the training data. + encoder_layers_activation (str): Optional. Activation function to use in the encoder layers. + optimizer (str): Optional. Optimizer to use for training. + tolerance (float): Optional. Maximum relative change in the loss function within the last + num_patience_epochs epochs below which early stopping is triggered. + num_patience_epochs (int): Optional. Number of successive epochs over which the early stopping criterion + is evaluated. + batch_norm (bool): Optional. Whether to use batch normalization during training. + rescale_gradient (float): Optional. Rescale factor for gradient. + clip_gradient (float): Optional. Maximum magnitude for each gradient component. + weight_decay (float): Optional. Weight decay coefficient. Adds L2 regularization. + learning_rate (float): Optional. Learning rate for the optimizer. + **kwargs: base class keyword argument values. + """ + + super(NTM, self).__init__(role, train_instance_count, train_instance_type, **kwargs) + self.num_topics = num_topics + self.encoder_layers = encoder_layers + self.epochs = epochs + self.encoder_layers_activation = encoder_layers_activation + self.optimizer = optimizer + self.tolerance = tolerance + self.num_patience_epochs = num_patience_epochs + self.batch_norm = batch_norm + self.rescale_gradient = rescale_gradient + self.clip_gradient = clip_gradient + self.weight_decay = weight_decay + self.learning_rate = learning_rate + + def create_model(self): + """Return a :class:`~sagemaker.amazon.ntm.NTMModel` referencing the latest + s3 model data produced by this Estimator.""" + + return NTMModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session) + + def fit(self, records, mini_batch_size=None, **kwargs): + if mini_batch_size is not None and (mini_batch_size < 1 or mini_batch_size > 10000): + raise ValueError("mini_batch_size must be in [1, 10000]") + super(NTM, self).fit(records, mini_batch_size, **kwargs) + + +class NTMPredictor(RealTimePredictor): + """Transforms input vectors to lower-dimensional representations. + + The implementation of :meth:`~sagemaker.predictor.RealTimePredictor.predict` in this + `RealTimePredictor` requires a numpy ``ndarray`` as input. The array should contain the + same number of columns as the feature-dimension of the data used to fit the model this + Predictor performs inference on. + + :meth:`predict()` returns a list of :class:`~sagemaker.amazon.record_pb2.Record` objects, one + for each row in the input ``ndarray``.
The topic weights for each input vector are stored in the ``topic_weights`` + key of the ``Record.label`` field.""" + + def __init__(self, endpoint, sagemaker_session=None): + super(NTMPredictor, self).__init__(endpoint, sagemaker_session, serializer=numpy_to_record_serializer(), + deserializer=record_deserializer()) + + +class NTMModel(Model): + """Reference NTM s3 model data. Calling :meth:`~sagemaker.model.Model.deploy` creates an Endpoint and returns + a Predictor that transforms vectors to a lower-dimensional representation.""" + + def __init__(self, model_data, role, sagemaker_session=None): + sagemaker_session = sagemaker_session or Session() + repo = '{}:{}'.format(NTM.repo_name, NTM.repo_version) + image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name, NTM.repo_name), repo) + super(NTMModel, self).__init__(model_data, image, role, predictor_cls=NTMPredictor, + sagemaker_session=sagemaker_session) diff --git a/src/sagemaker/amazon/validation.py b/src/sagemaker/amazon/validation.py index ede48cc9b3..7c7fa4f2a0 100644 --- a/src/sagemaker/amazon/validation.py +++ b/src/sagemaker/amazon/validation.py @@ -30,6 +30,12 @@ def validate(value): return validate +def le(maximum): + def validate(value): + return value <= maximum + return validate + + def isin(*expected): def validate(value): return value in expected diff --git a/tests/data/ntm/nips-train_1.pbr b/tests/data/ntm/nips-train_1.pbr new file mode 100644 index 0000000000..193cc98860 Binary files /dev/null and b/tests/data/ntm/nips-train_1.pbr differ diff --git a/tests/integ/record_set.py b/tests/integ/record_set.py new file mode 100644 index 0000000000..587ed88d14 --- /dev/null +++ b/tests/integ/record_set.py @@ -0,0 +1,23 @@ +from six.moves.urllib.parse import urlparse + +from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.utils import sagemaker_timestamp + + +def prepare_record_set_from_local_files(dir_path, destination, num_records, feature_dim, sagemaker_session): + """Build a :class:`~RecordSet` by pointing to local files. + + Args: + dir_path (string): Path to the local directory from which the files will be uploaded. + destination (string): S3 path to upload the files to. + num_records (int): Number of records in all the files. + feature_dim (int): Number of features in the data set. + sagemaker_session (sagemaker.session.Session): Session object to manage interactions with Amazon SageMaker APIs. + Returns: + RecordSet: A RecordSet specified by S3Prefix to be used in training. 
+ """ + key_prefix = urlparse(destination).path + key_prefix = key_prefix + '{}-{}'.format("testfiles", sagemaker_timestamp()) + key_prefix = key_prefix.lstrip('/') + uploaded_location = sagemaker_session.upload_data(path=dir_path, key_prefix=key_prefix) + return RecordSet(uploaded_location, num_records, feature_dim, s3_data_type='S3Prefix') diff --git a/tests/integ/test_lda.py b/tests/integ/test_lda.py index c6df685254..5e7619796e 100644 --- a/tests/integ/test_lda.py +++ b/tests/integ/test_lda.py @@ -13,16 +13,15 @@ import boto3 import numpy as np import os -from six.moves.urllib.parse import urlparse import sagemaker from sagemaker import LDA, LDAModel -from sagemaker.amazon.amazon_estimator import RecordSet from sagemaker.amazon.common import read_records -from sagemaker.utils import name_from_base, sagemaker_timestamp +from sagemaker.utils import name_from_base from tests.integ import DATA_DIR, REGION from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name +from tests.integ.record_set import prepare_record_set_from_local_files def test_lda(): @@ -41,8 +40,8 @@ def test_lda(): lda = LDA(role='SageMakerRole', train_instance_type='ml.c4.xlarge', num_topics=10, sagemaker_session=sagemaker_session, base_job_name='test-lda') - record_set = _prepare_record_set_from_local_files(data_path, lda.data_location, - len(all_records), feature_num, sagemaker_session) + record_set = prepare_record_set_from_local_files(data_path, lda.data_location, + len(all_records), feature_num, sagemaker_session) lda.fit(record_set, 100) endpoint_name = name_from_base('lda') @@ -56,22 +55,3 @@ def test_lda(): assert len(result) == 1 for record in result: assert record.label["topic_mixture"] is not None - - -def _prepare_record_set_from_local_files(dir_path, destination, num_records, feature_dim, sagemaker_session): - """Build a :class:`~RecordSet` by pointing to local files. - - Args: - dir_path (string): Path to local directory from where the files shall be uploaded. - destination (string): S3 path to upload the file to. - num_records (int): Number of records in all the files - feature_dim (int): Number of features in the data set - sagemaker_session (sagemaker.session.Session): Session object to manage interactions with Amazon SageMaker APIs. - Returns: - RecordSet: A RecordSet specified by S3Prefix to to be used in training. - """ - key_prefix = urlparse(destination).path - key_prefix = key_prefix + '{}-{}'.format("testfiles", sagemaker_timestamp()) - key_prefix = key_prefix.lstrip('/') - uploaded_location = sagemaker_session.upload_data(path=dir_path, key_prefix=key_prefix) - return RecordSet(uploaded_location, num_records, feature_dim, s3_data_type='S3Prefix') diff --git a/tests/integ/test_ntm.py b/tests/integ/test_ntm.py new file mode 100644 index 0000000000..6be0a2f3e9 --- /dev/null +++ b/tests/integ/test_ntm.py @@ -0,0 +1,57 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import boto3 +import numpy as np +import os + +import sagemaker +from sagemaker import NTM, NTMModel +from sagemaker.amazon.common import read_records +from sagemaker.utils import name_from_base + +from tests.integ import DATA_DIR, REGION +from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name +from tests.integ.record_set import prepare_record_set_from_local_files + + +def test_ntm(): + + with timeout(minutes=15): + sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION)) + data_path = os.path.join(DATA_DIR, 'ntm') + data_filename = 'nips-train_1.pbr' + + with open(os.path.join(data_path, data_filename), 'rb') as f: + all_records = read_records(f) + + # all records are assumed to share the same feature dimension + feature_num = int(all_records[0].features['values'].float32_tensor.shape[0]) + + ntm = NTM(role='SageMakerRole', train_instance_count=1, train_instance_type='ml.c4.xlarge', num_topics=10, + sagemaker_session=sagemaker_session, base_job_name='test-ntm') + + record_set = prepare_record_set_from_local_files(data_path, ntm.data_location, + len(all_records), feature_num, sagemaker_session) + ntm.fit(record_set, None) + + endpoint_name = name_from_base('ntm') + with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20): + model = NTMModel(ntm.model_data, role='SageMakerRole', sagemaker_session=sagemaker_session) + predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name) + + predict_input = np.random.rand(1, feature_num) + result = predictor.predict(predict_input) + + assert len(result) == 1 + for record in result: + assert record.label["topic_weights"] is not None diff --git a/tests/unit/test_ntm.py b/tests/unit/test_ntm.py new file mode 100644 index 0000000000..f248b73697 --- /dev/null +++ b/tests/unit/test_ntm.py @@ -0,0 +1,327 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+import pytest +from mock import Mock, patch + +from sagemaker.amazon.ntm import NTM, NTMPredictor +from sagemaker.amazon.amazon_estimator import registry, RecordSet + +ROLE = 'myrole' +TRAIN_INSTANCE_COUNT = 1 +TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge' +NUM_TOPICS = 5 + +COMMON_TRAIN_ARGS = {'role': ROLE, 'train_instance_count': TRAIN_INSTANCE_COUNT, + 'train_instance_type': TRAIN_INSTANCE_TYPE} +ALL_REQ_ARGS = dict({'num_topics': NUM_TOPICS}, **COMMON_TRAIN_ARGS) + +REGION = "us-west-2" +BUCKET_NAME = "Some-Bucket" + +DESCRIBE_TRAINING_JOB_RESULT = { + 'ModelArtifacts': { + 'S3ModelArtifacts': "s3://bucket/model.tar.gz" + } +} + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name='boto_session', region_name=REGION) + sms = Mock(name='sagemaker_session', boto_session=boto_mock) + sms.boto_region_name = REGION + sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', + return_value=DESCRIBE_TRAINING_JOB_RESULT) + + return sms + + +def test_init_required_positional(sagemaker_session): + ntm = NTM(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_TOPICS, sagemaker_session=sagemaker_session) + assert ntm.role == ROLE + assert ntm.train_instance_count == TRAIN_INSTANCE_COUNT + assert ntm.train_instance_type == TRAIN_INSTANCE_TYPE + assert ntm.num_topics == NUM_TOPICS + + +def test_init_required_named(sagemaker_session): + ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + assert ntm.role == COMMON_TRAIN_ARGS['role'] + assert ntm.train_instance_count == TRAIN_INSTANCE_COUNT + assert ntm.train_instance_type == COMMON_TRAIN_ARGS['train_instance_type'] + assert ntm.num_topics == ALL_REQ_ARGS['num_topics'] + + +def test_all_hyperparameters(sagemaker_session): + ntm = NTM(sagemaker_session=sagemaker_session, + encoder_layers=[1, 2, 3], epochs=3, encoder_layers_activation='tanh', optimizer='sgd', + tolerance=0.05, num_patience_epochs=2, batch_norm=False, rescale_gradient=0.5, clip_gradient=0.5, + weight_decay=0.5, learning_rate=0.5, **ALL_REQ_ARGS) + assert ntm.hyperparameters() == dict( + num_topics=str(ALL_REQ_ARGS['num_topics']), + encoder_layers='[1, 2, 3]', + epochs='3', + encoder_layers_activation='tanh', + optimizer='sgd', + tolerance='0.05', + num_patience_epochs='2', + batch_norm='False', + rescale_gradient='0.5', + clip_gradient='0.5', + weight_decay='0.5', + learning_rate='0.5' + ) + + +def test_image(sagemaker_session): + ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + assert ntm.train_image() == registry(REGION, "ntm") + '/ntm:1' + + +def test_num_topics_validation_fail_type(sagemaker_session): + with pytest.raises(ValueError): + NTM(num_topics='other', sagemaker_session=sagemaker_session, **COMMON_TRAIN_ARGS) + + +def test_num_topics_validation_fail_value_lower(sagemaker_session): + with pytest.raises(ValueError): + NTM(num_topics=0, sagemaker_session=sagemaker_session, **COMMON_TRAIN_ARGS) + + +def test_num_topics_validation_fail_value_upper(sagemaker_session): + with pytest.raises(ValueError): + NTM(num_topics=10000, sagemaker_session=sagemaker_session, **COMMON_TRAIN_ARGS) + + +def test_encoder_layers_validation_fail_type(sagemaker_session): + with pytest.raises(TypeError): + NTM(encoder_layers=0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_epochs_validation_fail_type(sagemaker_session): + with pytest.raises(ValueError): + NTM(epochs='other', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def 
test_epochs_validation_fail_value_lower(sagemaker_session): + with pytest.raises(ValueError): + NTM(epochs=0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_epochs_validation_fail_value_upper(sagemaker_session): + with pytest.raises(ValueError): + NTM(epochs=1000, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_encoder_layers_activation_validation_fail_type(sagemaker_session): + with pytest.raises(ValueError): + NTM(encoder_layers_activation=0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_encoder_layers_activation_validation_fail_value(sagemaker_session): + with pytest.raises(ValueError): + NTM(encoder_layers_activation='other', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_optimizer_validation_fail_type(sagemaker_session): + with pytest.raises(ValueError): + NTM(optimizer=0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_optimizer_validation_fail_value(sagemaker_session): + with pytest.raises(ValueError): + NTM(optimizer='other', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_tolerance_validation_fail_type(sagemaker_session): + with pytest.raises(ValueError): + NTM(tolerance='other', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_tolerance_validation_fail_value_lower(sagemaker_session): + with pytest.raises(ValueError): + NTM(tolerance=0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_tolerance_validation_fail_value_upper(sagemaker_session): + with pytest.raises(ValueError): + NTM(tolerance=0.5, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_num_patience_epochs_validation_fail_type(sagemaker_session): + with pytest.raises(ValueError): + NTM(num_patience_epochs='other', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_num_patience_epochs_validation_fail_value_lower(sagemaker_session): + with pytest.raises(ValueError): + NTM(num_patience_epochs=0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_num_patience_epochs_validation_fail_value_upper(sagemaker_session): + with pytest.raises(ValueError): + NTM(num_patience_epochs=100, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_rescale_gradient_fail_type(sagemaker_session): + with pytest.raises(ValueError): + NTM(rescale_gradient='other', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_rescale_gradient_validation_fail_value_lower(sagemaker_session): + with pytest.raises(ValueError): + NTM(rescale_gradient=0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_rescale_gradient_validation_fail_value_upper(sagemaker_session): + with pytest.raises(ValueError): + NTM(rescale_gradient=10, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_clip_gradient_fail_type(sagemaker_session): + with pytest.raises(ValueError): + NTM(clip_gradient='other', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_clip_gradient_validation_fail_value(sagemaker_session): + with pytest.raises(ValueError): + NTM(clip_gradient=0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_weight_decay_fail_type(sagemaker_session): + with pytest.raises(ValueError): + NTM(weight_decay='other', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_weight_decay_validation_fail_value_lower(sagemaker_session): + with pytest.raises(ValueError): + NTM(weight_decay=-1, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def 
test_weight_decay_validation_fail_value_upper(sagemaker_session): + with pytest.raises(ValueError): + NTM(weight_decay=2, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_learning_rate_fail_type(sagemaker_session): + with pytest.raises(ValueError): + NTM(learning_rate='other', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_learning_rate_validation_fail_value_lower(sagemaker_session): + with pytest.raises(ValueError): + NTM(learning_rate=0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_learning_rate_validation_fail_value_upper(sagemaker_session): + with pytest.raises(ValueError): + NTM(learning_rate=2, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +PREFIX = "prefix" +BASE_TRAIN_CALL = { + 'hyperparameters': {}, + 'image': registry(REGION, "ntm") + '/ntm:1', + 'input_config': [{ + 'DataSource': { + 'S3DataSource': { + 'S3DataDistributionType': 'ShardedByS3Key', + 'S3DataType': 'ManifestFile', + 'S3Uri': 's3://{}/{}'.format(BUCKET_NAME, PREFIX) + } + }, + 'ChannelName': 'train' + }], + 'input_mode': 'File', + 'output_config': {'S3OutputPath': 's3://{}/'.format(BUCKET_NAME)}, + 'resource_config': { + 'InstanceCount': TRAIN_INSTANCE_COUNT, + 'InstanceType': TRAIN_INSTANCE_TYPE, + 'VolumeSizeInGB': 30 + }, + 'stop_condition': {'MaxRuntimeInSeconds': 86400} +} + +FEATURE_DIM = 10 +MINI_BATCH_SIZE = 200 + + +@patch("sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit") +def test_call_fit(base_fit, sagemaker_session): + ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + + ntm.fit(data, MINI_BATCH_SIZE) + + base_fit.assert_called_once() + assert len(base_fit.call_args[0]) == 2 + assert base_fit.call_args[0][0] == data + assert base_fit.call_args[0][1] == MINI_BATCH_SIZE + + +def test_call_fit_none_mini_batch_size(sagemaker_session): + ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, + channel='train') + ntm.fit(data) + + +def test_call_fit_wrong_type_mini_batch_size(sagemaker_session): + ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, + channel='train') + + with pytest.raises((TypeError, ValueError)): + ntm.fit(data, "some") + + +def test_call_fit_wrong_value_lower_mini_batch_size(sagemaker_session): + ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, + channel='train') + with pytest.raises(ValueError): + ntm.fit(data, 0) + + +def test_call_fit_wrong_value_upper_mini_batch_size(sagemaker_session): + ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, + channel='train') + with pytest.raises(ValueError): + ntm.fit(data, 10001) + + +def test_model_image(sagemaker_session): + ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + ntm.fit(data, MINI_BATCH_SIZE) + + model = ntm.create_model() + assert 
model.image == registry(REGION, "ntm") + '/ntm:1' + + +def test_predictor_type(sagemaker_session): + ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + ntm.fit(data, MINI_BATCH_SIZE) + model = ntm.create_model() + predictor = model.deploy(1, TRAIN_INSTANCE_TYPE) + + assert isinstance(predictor, NTMPredictor)
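+
+
+# A hedged end-to-end sketch (illustrative only, not part of the test suite): once an
+# endpoint is live, topic weights can be read from each returned Record. The
+# ``feature_num`` value and the random input below are placeholder assumptions.
+#
+#     result = predictor.predict(np.random.rand(1, feature_num))
+#     topic_weights = result[0].label["topic_weights"].float32_tensor.values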