diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 86b55eb7fa..ccf88d1de9 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,15 @@ CHANGELOG ========= +1.0.3 +===== + +* feature: Estimators: add support for Amazon LDA algorithm +* feature: Hyperparameters: Add data_type to hyperparameters +* feature: Documentation: Update TensorFlow examples following API change +* feature: Session: Support multi-part uploads + + 1.0.2 ===== diff --git a/setup.py b/setup.py index 6b048f83fc..1edd9dd2dc 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ def read(fname): setup(name="sagemaker", - version="1.0.2", + version="1.0.3", description="Open source library for training and deploying models on Amazon SageMaker.", packages=find_packages('src'), package_dir={'': 'src'}, diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index 098c067f96..b293cf2f5c 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -15,6 +15,7 @@ from sagemaker import estimator from sagemaker.amazon.kmeans import KMeans, KMeansModel, KMeansPredictor from sagemaker.amazon.pca import PCA, PCAModel, PCAPredictor +from sagemaker.amazon.lda import LDA, LDAModel, LDAPredictor from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerModel, LinearLearnerPredictor from sagemaker.amazon.factorization_machines import FactorizationMachines, FactorizationMachinesModel from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor @@ -30,6 +31,7 @@ __all__ = [estimator, KMeans, KMeansModel, KMeansPredictor, PCA, PCAModel, PCAPredictor, LinearLearner, LinearLearnerModel, LinearLearnerPredictor, + LDA, LDAModel, LDAPredictor, FactorizationMachines, FactorizationMachinesModel, FactorizationMachinesPredictor, Model, RealTimePredictor, Session, container_def, s3_input, production_variant, get_execution_role] diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 9fbdb8b631..cc92fec0f4 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -47,7 +47,8 @@ def __init__(self, role, train_instance_count, train_instance_type, data_locatio self.data_location = data_location def train_image(self): - return registry(self.sagemaker_session.boto_region_name) + "/" + type(self).repo + repo = '{}:{}'.format(type(self).repo_name, type(self).repo_version) + return '{}/{}'.format(registry(self.sagemaker_session.boto_region_name, type(self).repo_name), repo) def hyperparameters(self): return hp.serialize_all(self) @@ -200,12 +201,22 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels= raise ex -def registry(region_name): +def registry(region_name, algorithm=None): """Return docker registry for the given AWS region""" - account_id = { - "us-east-1": "382416733822", - "us-east-2": "404615174143", - "us-west-2": "174872318107", - "eu-west-1": "438346466558" - }[region_name] + if algorithm in [None, "pca", "kmeans", "linear-learner", "factorization-machines"]: + account_id = { + "us-east-1": "382416733822", + "us-east-2": "404615174143", + "us-west-2": "174872318107", + "eu-west-1": "438346466558" + }[region_name] + elif algorithm in ["lda"]: + account_id = { + "us-east-1": "766337827248", + "us-east-2": "999911452149", + "us-west-2": "266724342769", + "eu-west-1": "999678624901" + }[region_name] + else: + raise ValueError("Algorithm class:{} doesn't have mapping to account_id with images".format(algorithm)) return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name) diff --git a/src/sagemaker/amazon/factorization_machines.py b/src/sagemaker/amazon/factorization_machines.py index 66972316ac..5297367947 100644 --- a/src/sagemaker/amazon/factorization_machines.py +++ b/src/sagemaker/amazon/factorization_machines.py @@ -21,7 +21,8 @@ class FactorizationMachines(AmazonAlgorithmEstimatorBase): - repo = 'factorization-machines:1' + repo_name = 'factorization-machines' + repo_version = 1 num_factors = hp('num_factors', gt(0), 'An integer greater than zero', int) predictor_type = hp('predictor_type', isin('binary_classifier', 'regressor'), @@ -194,7 +195,8 @@ class FactorizationMachinesModel(Model): def __init__(self, model_data, role, sagemaker_session=None): sagemaker_session = sagemaker_session or Session() - image = registry(sagemaker_session.boto_session.region_name) + "/" + FactorizationMachines.repo + repo = '{}:{}'.format(FactorizationMachines.repo_name, FactorizationMachines.repo_version) + image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name), repo) super(FactorizationMachinesModel, self).__init__(model_data, image, role, diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index 020f496d8f..b684b68f07 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -21,7 +21,8 @@ class KMeans(AmazonAlgorithmEstimatorBase): - repo = 'kmeans:1' + repo_name = 'kmeans' + repo_version = 1 k = hp('k', gt(1), 'An integer greater-than 1', int) init_method = hp('init_method', isin('random', 'kmeans++'), 'One of "random", "kmeans++"', str) @@ -132,6 +133,7 @@ class KMeansModel(Model): def __init__(self, model_data, role, sagemaker_session=None): sagemaker_session = sagemaker_session or Session() - image = registry(sagemaker_session.boto_session.region_name) + "/" + KMeans.repo + repo = '{}:{}'.format(KMeans.repo_name, KMeans.repo_version) + image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name), repo) super(KMeansModel, self).__init__(model_data, image, role, predictor_cls=KMeansPredictor, sagemaker_session=sagemaker_session) diff --git a/src/sagemaker/amazon/lda.py b/src/sagemaker/amazon/lda.py new file mode 100644 index 0000000000..30367b2b0f --- /dev/null +++ b/src/sagemaker/amazon/lda.py @@ -0,0 +1,127 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase, registry +from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer +from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa +from sagemaker.amazon.validation import gt +from sagemaker.predictor import RealTimePredictor +from sagemaker.model import Model +from sagemaker.session import Session + + +class LDA(AmazonAlgorithmEstimatorBase): + + repo_name = 'lda' + repo_version = 1 + + num_topics = hp('num_topics', gt(0), 'An integer greater than zero', int) + alpha0 = hp('alpha0', gt(0), 'A positive float', float) + max_restarts = hp('max_restarts', gt(0), 'An integer greater than zero', int) + max_iterations = hp('max_iterations', gt(0), 'An integer greater than zero', int) + tol = hp('tol', gt(0), 'A positive float', float) + + def __init__(self, role, train_instance_type, num_topics, + alpha0=None, max_restarts=None, max_iterations=None, tol=None, **kwargs): + """Latent Dirichlet Allocation (LDA) is :class:`Estimator` used for unsupervised learning. + + Amazon SageMaker Latent Dirichlet Allocation is an unsupervised learning algorithm that attempts to describe + a set of observations as a mixture of distinct categories. LDA is most commonly used to discover + a user-specified number of topics shared by documents within a text corpus. + Here each observation is a document, the features are the presence (or occurrence count) of each word, and + the categories are the topics. + + This Estimator may be fit via calls to + :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit`. It requires Amazon + :class:`~sagemaker.amazon.record_pb2.Record` protobuf serialized data to be stored in S3. + There is an utility :meth:`~sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.record_set` that + can be used to upload data to S3 and creates :class:`~sagemaker.amazon.amazon_estimator.RecordSet` to be passed + to the `fit` call. + + To learn more about the Amazon protobuf Record class and how to prepare bulk data in this format, please + consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html + + After this Estimator is fit, model data is stored in S3. The model may be deployed to an Amazon SageMaker + Endpoint by invoking :meth:`~sagemaker.amazon.estimator.EstimatorBase.deploy`. As well as deploying an Endpoint, + deploy returns a :class:`~sagemaker.amazon.lda.LDAPredictor` object that can be used + for inference calls using the trained model hosted in the SageMaker Endpoint. + + LDA Estimators can be configured by setting hyperparameters. The available hyperparameters for + LDA are documented below. + + For further information on the AWS LDA algorithm, + please consult AWS technical documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/lda.html + + Args: + role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and + APIs that create Amazon SageMaker endpoints use this role to access + training data and model artifacts. After the endpoint is created, + the inference code might use the IAM role, if accessing AWS resource. + train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'. + num_topics (int): The number of topics for LDA to find within the data. + alpha0 (float): Optional. Initial guess for the concentration parameter + max_restarts (int): Optional. The number of restarts to perform during the Alternating Least Squares (ALS) + spectral decomposition phase of the algorithm. + max_iterations (int): Optional. The maximum number of iterations to perform during the + ALS phase of the algorithm. + tol (float): Optional. Target error tolerance for the ALS phase of the algorithm. + **kwargs: base class keyword argument values. + """ + + # this algorithm only supports single instance training + super(LDA, self).__init__(role, 1, train_instance_type, **kwargs) + self.num_topics = num_topics + self.alpha0 = alpha0 + self.max_restarts = max_restarts + self.max_iterations = max_iterations + self.tol = tol + + def create_model(self): + """Return a :class:`~sagemaker.amazon.LDAModel` referencing the latest + s3 model data produced by this Estimator.""" + + return LDAModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session) + + def fit(self, records, mini_batch_size, **kwargs): + # mini_batch_size is required, prevent explicit calls with None + if mini_batch_size is None: + raise ValueError("mini_batch_size must be set") + super(LDA, self).fit(records, mini_batch_size, **kwargs) + + +class LDAPredictor(RealTimePredictor): + """Transforms input vectors to lower-dimesional representations. + + The implementation of :meth:`~sagemaker.predictor.RealTimePredictor.predict` in this + `RealTimePredictor` requires a numpy ``ndarray`` as input. The array should contain the + same number of columns as the feature-dimension of the data used to fit the model this + Predictor performs inference on. + + :meth:`predict()` returns a list of :class:`~sagemaker.amazon.record_pb2.Record` objects, one + for each row in the input ``ndarray``. The lower dimension vector result is stored in the ``projection`` + key of the ``Record.label`` field.""" + + def __init__(self, endpoint, sagemaker_session=None): + super(LDAPredictor, self).__init__(endpoint, sagemaker_session, serializer=numpy_to_record_serializer(), + deserializer=record_deserializer()) + + +class LDAModel(Model): + """Reference LDA s3 model data. Calling :meth:`~sagemaker.model.Model.deploy` creates an Endpoint and return + a Predictor that transforms vectors to a lower-dimensional representation.""" + + def __init__(self, model_data, role, sagemaker_session=None): + sagemaker_session = sagemaker_session or Session() + repo = '{}:{}'.format(LDA.repo_name, LDA.repo_version) + image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name, LDA.repo_name), repo) + super(LDAModel, self).__init__(model_data, image, role, predictor_cls=LDAPredictor, + sagemaker_session=sagemaker_session) diff --git a/src/sagemaker/amazon/linear_learner.py b/src/sagemaker/amazon/linear_learner.py index d1d9dd6cb8..bd5f2b9682 100644 --- a/src/sagemaker/amazon/linear_learner.py +++ b/src/sagemaker/amazon/linear_learner.py @@ -21,7 +21,8 @@ class LinearLearner(AmazonAlgorithmEstimatorBase): - repo = 'linear-learner:1' + repo_name = 'linear-learner' + repo_version = 1 DEFAULT_MINI_BATCH_SIZE = 1000 @@ -226,7 +227,8 @@ class LinearLearnerModel(Model): def __init__(self, model_data, role, sagemaker_session=None): sagemaker_session = sagemaker_session or Session() - image = registry(sagemaker_session.boto_session.region_name) + "/" + LinearLearner.repo + repo = '{}:{}'.format(LinearLearner.repo_name, LinearLearner.repo_version) + image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name), repo) super(LinearLearnerModel, self).__init__(model_data, image, role, predictor_cls=LinearLearnerPredictor, sagemaker_session=sagemaker_session) diff --git a/src/sagemaker/amazon/pca.py b/src/sagemaker/amazon/pca.py index 48f99c04d1..fa0f7e7217 100644 --- a/src/sagemaker/amazon/pca.py +++ b/src/sagemaker/amazon/pca.py @@ -20,7 +20,8 @@ class PCA(AmazonAlgorithmEstimatorBase): - repo = 'pca:1' + repo_name = 'pca' + repo_version = 1 DEFAULT_MINI_BATCH_SIZE = 500 @@ -118,6 +119,7 @@ class PCAModel(Model): def __init__(self, model_data, role, sagemaker_session=None): sagemaker_session = sagemaker_session or Session() - image = registry(sagemaker_session.boto_session.region_name) + "/" + PCA.repo + repo = '{}:{}'.format(PCA.repo_name, PCA.repo_version) + image = '{}/{}'.format(registry(sagemaker_session.boto_session.region_name), repo) super(PCAModel, self).__init__(model_data, image, role, predictor_cls=PCAPredictor, sagemaker_session=sagemaker_session) diff --git a/tests/data/lda/nips-train_1.pbr b/tests/data/lda/nips-train_1.pbr new file mode 100644 index 0000000000..193cc98860 Binary files /dev/null and b/tests/data/lda/nips-train_1.pbr differ diff --git a/tests/integ/test_lda.py b/tests/integ/test_lda.py new file mode 100644 index 0000000000..c6df685254 --- /dev/null +++ b/tests/integ/test_lda.py @@ -0,0 +1,77 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import boto3 +import numpy as np +import os +from six.moves.urllib.parse import urlparse + +import sagemaker +from sagemaker import LDA, LDAModel +from sagemaker.amazon.amazon_estimator import RecordSet +from sagemaker.amazon.common import read_records +from sagemaker.utils import name_from_base, sagemaker_timestamp + +from tests.integ import DATA_DIR, REGION +from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name + + +def test_lda(): + + with timeout(minutes=15): + sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION)) + data_path = os.path.join(DATA_DIR, 'lda') + data_filename = 'nips-train_1.pbr' + + with open(os.path.join(data_path, data_filename), 'rb') as f: + all_records = read_records(f) + + # all records must be same + feature_num = int(all_records[0].features['values'].float32_tensor.shape[0]) + + lda = LDA(role='SageMakerRole', train_instance_type='ml.c4.xlarge', num_topics=10, + sagemaker_session=sagemaker_session, base_job_name='test-lda') + + record_set = _prepare_record_set_from_local_files(data_path, lda.data_location, + len(all_records), feature_num, sagemaker_session) + lda.fit(record_set, 100) + + endpoint_name = name_from_base('lda') + with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20): + model = LDAModel(lda.model_data, role='SageMakerRole', sagemaker_session=sagemaker_session) + predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name) + + predict_input = np.random.rand(1, feature_num) + result = predictor.predict(predict_input) + + assert len(result) == 1 + for record in result: + assert record.label["topic_mixture"] is not None + + +def _prepare_record_set_from_local_files(dir_path, destination, num_records, feature_dim, sagemaker_session): + """Build a :class:`~RecordSet` by pointing to local files. + + Args: + dir_path (string): Path to local directory from where the files shall be uploaded. + destination (string): S3 path to upload the file to. + num_records (int): Number of records in all the files + feature_dim (int): Number of features in the data set + sagemaker_session (sagemaker.session.Session): Session object to manage interactions with Amazon SageMaker APIs. + Returns: + RecordSet: A RecordSet specified by S3Prefix to to be used in training. + """ + key_prefix = urlparse(destination).path + key_prefix = key_prefix + '{}-{}'.format("testfiles", sagemaker_timestamp()) + key_prefix = key_prefix.lstrip('/') + uploaded_location = sagemaker_session.upload_data(path=dir_path, key_prefix=key_prefix) + return RecordSet(uploaded_location, num_records, feature_dim, s3_data_type='S3Prefix') diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index a9eb15886e..005a3ee9d8 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -18,7 +18,6 @@ from sagemaker.amazon.pca import PCA from sagemaker.amazon.amazon_estimator import upload_numpy_to_s3_shards, _build_shards, registry - COMMON_ARGS = {'role': 'myrole', 'train_instance_count': 1, 'train_instance_type': 'ml.c4.xlarge'} REGION = "us-west-2" diff --git a/tests/unit/test_lda.py b/tests/unit/test_lda.py new file mode 100644 index 0000000000..59618a6dd9 --- /dev/null +++ b/tests/unit/test_lda.py @@ -0,0 +1,224 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import pytest +from mock import Mock, patch + +from sagemaker.amazon.lda import LDA, LDAPredictor +from sagemaker.amazon.amazon_estimator import registry, RecordSet + +ROLE = 'myrole' +TRAIN_INSTANCE_COUNT = 1 +TRAIN_INSTANCE_TYPE = 'ml.c4.xlarge' +NUM_TOPICS = 3 + +COMMON_TRAIN_ARGS = {'role': ROLE, 'train_instance_type': TRAIN_INSTANCE_TYPE} +ALL_REQ_ARGS = dict({'num_topics': NUM_TOPICS}, **COMMON_TRAIN_ARGS) + +REGION = "us-west-2" +BUCKET_NAME = "Some-Bucket" + +DESCRIBE_TRAINING_JOB_RESULT = { + 'ModelArtifacts': { + 'S3ModelArtifacts': "s3://bucket/model.tar.gz" + } +} + + +@pytest.fixture() +def sagemaker_session(): + boto_mock = Mock(name='boto_session', region_name=REGION) + sms = Mock(name='sagemaker_session', boto_session=boto_mock) + sms.boto_region_name = REGION + sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME) + sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job', + return_value=DESCRIBE_TRAINING_JOB_RESULT) + + return sms + + +def test_init_required_positional(sagemaker_session): + lda = LDA(ROLE, TRAIN_INSTANCE_TYPE, NUM_TOPICS, sagemaker_session=sagemaker_session) + assert lda.role == ROLE + assert lda.train_instance_count == TRAIN_INSTANCE_COUNT + assert lda.train_instance_type == TRAIN_INSTANCE_TYPE + assert lda.num_topics == NUM_TOPICS + + +def test_init_required_named(sagemaker_session): + lda = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + assert lda.role == COMMON_TRAIN_ARGS['role'] + assert lda.train_instance_count == TRAIN_INSTANCE_COUNT + assert lda.train_instance_type == COMMON_TRAIN_ARGS['train_instance_type'] + assert lda.num_topics == ALL_REQ_ARGS['num_topics'] + + +def test_all_hyperparameters(sagemaker_session): + lda = LDA(sagemaker_session=sagemaker_session, + alpha0=2.2, max_restarts=3, max_iterations=10, tol=3.3, + **ALL_REQ_ARGS) + assert lda.hyperparameters() == dict( + num_topics=str(ALL_REQ_ARGS['num_topics']), + alpha0='2.2', + max_restarts='3', + max_iterations='10', + tol='3.3', + ) + + +def test_image(sagemaker_session): + lda = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + assert lda.train_image() == registry(REGION, "lda") + '/lda:1' + + +def test_num_topics_validation_fail_type(sagemaker_session): + with pytest.raises(ValueError): + LDA(num_topics='other', sagemaker_session=sagemaker_session, **COMMON_TRAIN_ARGS) + + +def test_num_topics_validation_fail_value(sagemaker_session): + with pytest.raises(ValueError): + LDA(num_topics=0, sagemaker_session=sagemaker_session, **COMMON_TRAIN_ARGS) + + +def test_alpha0_validation_fail_type(sagemaker_session): + with pytest.raises(ValueError): + LDA(alpha0='other', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_max_restarts_validation_fail_type(sagemaker_session): + with pytest.raises(ValueError): + LDA(max_restarts='other', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_max_restarts_validation_fail_type2(sagemaker_session): + with pytest.raises(ValueError): + LDA(max_restarts=0.1, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_max_restarts_validation_fail_value(sagemaker_session): + with pytest.raises(ValueError): + LDA(max_restarts=0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_max_iterations_validation_fail_type(sagemaker_session): + with pytest.raises(ValueError): + LDA(max_iterations='other', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_max_iterations_validation_fail_value(sagemaker_session): + with pytest.raises(ValueError): + LDA(max_iterations=0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_tol_validation_fail_type(sagemaker_session): + with pytest.raises(ValueError): + LDA(tol='other', sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +def test_tol_validation_fail_value(sagemaker_session): + with pytest.raises(ValueError): + LDA(tol=0, sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + +PREFIX = "prefix" +BASE_TRAIN_CALL = { + 'hyperparameters': {}, + 'image': registry(REGION, "lda") + '/lda:1', + 'input_config': [{ + 'DataSource': { + 'S3DataSource': { + 'S3DataDistributionType': 'ShardedByS3Key', + 'S3DataType': 'ManifestFile', + 'S3Uri': 's3://{}/{}'.format(BUCKET_NAME, PREFIX) + } + }, + 'ChannelName': 'train' + }], + 'input_mode': 'File', + 'output_config': {'S3OutputPath': 's3://{}/'.format(BUCKET_NAME)}, + 'resource_config': { + 'InstanceCount': TRAIN_INSTANCE_COUNT, + 'InstanceType': TRAIN_INSTANCE_TYPE, + 'VolumeSizeInGB': 30 + }, + 'stop_condition': {'MaxRuntimeInSeconds': 86400} +} + +FEATURE_DIM = 10 +MINI_BATCH_SZIE = 200 +HYPERPARAMS = {'num_topics': NUM_TOPICS, 'feature_dim': FEATURE_DIM, 'mini_batch_size': MINI_BATCH_SZIE} +STRINGIFIED_HYPERPARAMS = dict([(x, str(y)) for x, y in HYPERPARAMS.items()]) +HP_TRAIN_CALL = dict(BASE_TRAIN_CALL) +HP_TRAIN_CALL.update({'hyperparameters': STRINGIFIED_HYPERPARAMS}) + + +@patch("sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit") +def test_call_fit(base_fit, sagemaker_session): + lda = LDA(base_job_name="lda", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + + lda.fit(data, MINI_BATCH_SZIE) + + base_fit.assert_called_once() + assert len(base_fit.call_args[0]) == 2 + assert base_fit.call_args[0][0] == data + assert base_fit.call_args[0][1] == MINI_BATCH_SZIE + + +def test_call_fit_none_mini_batch_size(sagemaker_session): + lda = LDA(base_job_name="lda", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, + channel='train') + with pytest.raises(ValueError): + lda.fit(data, None) + + +def test_call_fit_wrong_type_mini_batch_size(sagemaker_session): + lda = LDA(base_job_name="lda", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, + channel='train') + + with pytest.raises(ValueError): + lda.fit(data, "some") + + +def test_call_fit_wrong_value_mini_batch_size(sagemaker_session): + lda = LDA(base_job_name="lda", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + + data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, + channel='train') + with pytest.raises(ValueError): + lda.fit(data, 0) + + +def test_model_image(sagemaker_session): + lda = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + lda.fit(data, MINI_BATCH_SZIE) + + model = lda.create_model() + assert model.image == registry(REGION, "lda") + '/lda:1' + + +def test_predictor_type(sagemaker_session): + lda = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + data = RecordSet("s3://{}/{}".format(BUCKET_NAME, PREFIX), num_records=1, feature_dim=FEATURE_DIM, channel='train') + lda.fit(data, MINI_BATCH_SZIE) + model = lda.create_model() + predictor = model.deploy(1, TRAIN_INSTANCE_TYPE) + + assert isinstance(predictor, LDAPredictor)