From ac7b85406b5fe92a5d94696c228a5920f18d44b4 Mon Sep 17 00:00:00 2001 From: Ragav Venkatesan Date: Tue, 16 Jan 2018 19:20:24 -0800 Subject: [PATCH 1/5] image classification algorithm api --- src/sagemaker/amazon/image_classification.py | 205 +++++++++++++++++++ tests/integ/test_image_classification.py | 72 +++++++ 2 files changed, 277 insertions(+) create mode 100644 src/sagemaker/amazon/image_classification.py create mode 100644 tests/integ/test_image_classification.py diff --git a/src/sagemaker/amazon/image_classification.py b/src/sagemaker/amazon/image_classification.py new file mode 100644 index 0000000000..70b9ed1fea --- /dev/null +++ b/src/sagemaker/amazon/image_classification.py @@ -0,0 +1,205 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from sagemaker.amazon.amazon_estimator import AmazonS3AlgorithmEstimatorBase, registry +from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa +from sagemaker.amazon.validation import gt, isin, isint, ge, isstr, lt, le +from sagemaker.predictor import RealTimePredictor +from sagemaker.model import Model +from sagemaker.session import Session + + +class ImageClassification(AmazonS3AlgorithmEstimatorBase): + + repo = 'image-classification:latest' + + num_classes = hp('num_classes', (gt(1), isint), 'num_classes should be an integer greater-than 1') + num_training_samples = hp('num_training_samples', (gt(1), isint), 'num_training_samples should be an integer greater-than 1') + use_pretrained_model = hp('use_pretrained_model', (isin(0, 1), isint), 'use_pretrained_model should be in the set, [0,1]') + checkpoint_frequency = hp('checkpoint_frequency', (ge(1), isint), 'checkpoint_frequency should be an integer greater-than 1') + num_layers = hp('num_layers', (isin(18, 34, 50, 101, 152, 200, 20, 32, 44, 56, 110), isint), \ + 'num_layers should be in the set [18, 34, 50, 101, 152, 200, 20, 32, 44, 56, 110]' ) + resize = hp('resize', (gt(1), isint), 'resize should be an integer greater-than 1') + epochs = hp('epochs', (ge(1), isint), 'epochs should be an integer greater-than 1') + learning_rate = hp('learning_rate', (gt(0)), 'learning_rate shoudl be a floating point greater than 0' ) + lr_schedule_factor = hp ('lr_schedule_factor', (gt(0)), 'lr_schedule_factor should be a floating point greater than 0') + lr_scheduler_step = hp ('lr_scheduler_step' ,(isstr), 'lr_scheduler_step should be a string input.') + optimizer = hp ('optimizer', (isin('sgd', 'adam', 'rmsprop', 'nag')), \ + 'Should be one optimizer among the list sgd, adam, rmsprop, or nag.') + momentum = hp ('momentum', (ge(0), le(1)), 'momentum is expected in the range 0, 1') + weight_decay = hp ('weight_decay', (ge(0), le(1)), 'weight_decay in range 0 , 1 ') + beta_1 = hp ('beta_1', (ge(0), le(1)), 'beta_1 shoud be in range 0, 1') + beta_2 = hp ('beta_2', (ge(0), le(1)), 'beta_2 should be in the range 0, 1') + eps = hp ('eps', (gt(0), le(1)), 'eps should be in the range 0, 1') + gamma = hp ('gamma', (ge(0), le(1)), 'gamma should be in the range 0, 1') + mini_batch_size = hp ('mini_batch_size', (gt(0)), 'mini_batch_size should be an integer greater than 0') + image_shape = hp ('image_shape', (isstr), 'image_shape is expected to be a string') + augmentation_type = hp ('beta_1', (isin ('crop', 'crop_color', 'crop_color_transform')), \ + 'beta_1 must be from one option offered') + top_k = hp ('top_k', (ge(1), isint), 'top_k should be greater than or equal to 1') + kv_store = hp ('kv_store', (isin ('dist_sync', 'dist_async' )), 'Can be dist_sync or dist_async') + + + + def __init__(self, role, train_instance_count, train_instance_type, num_classes, num_training_samples, resize = None, + lr_scheduler_step = None, use_pretrained_model = 0, checkpoint_frequency = 1 , num_layers = 152, + epochs = 30, learning_rate = 0.1, + lr_schedule_factor = 0.1, optimizer = 'sgd', momentum = 0., weight_decay = 0.0001, beta_1 = 0.9, + beta_2 = 0.999, eps = 1e-8, gamma = 0.9 , mini_batch_size = 32 , image_shape = '3,224,224', + augmentation_type = None, top_k = None, kv_store = None, **kwargs): + """ + An Image classification algorithm :class:`~sagemaker.amazon.AmazonAlgorithmEstimatorBase`. Learns a classifier model that + + This Estimator may be fit via calls to + :meth:`~sagemaker.amazon.amazon_estimator.AmazonS3AlgorithmEstimatorBase.fit` + + After this Estimator is fit, model data is stored in S3. The model may be deployed to an Amazon SageMaker + Endpoint by invoking :meth:`~sagemaker.amazon.estimator.EstimatorBase.deploy`. As well as deploying an Endpoint, + ``deploy`` returns a :class:`~sagemaker.amazon.kmeans.ImageClassificationPredictor` object that can be used to label + assignment, using the trained model hosted in the SageMaker Endpoint. + + ImageClassification Estimators can be configured by setting hyperparameters. The available hyperparameters for + ImageClassification are documented below. For further information on the AWS ImageClassification algorithm, please consult AWS technical + documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/IC-Hyperparameter.html + + Args: + role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and + APIs that create Amazon SageMaker endpoints use this role to access + training data and model artifacts. After the endpoint is created, + the inference code might use the IAM role, if accessing AWS resource. + For more information, see ???. + train_instance_count (int): Number of Amazon EC2 instances to use for training. + train_instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'. + num_classes (int): Number of output classes. This parameter defines the dimensions of the network output + and is typically set to the number of classes in the dataset. + num_training_samples (int): Number of training examples in the input dataset. If there is a + mismatch between this value and the number of samples in the training + set, then the behavior of the lr_scheduler_step parameter is undefined + and distributed training accuracy might be affected. + use_pretrained_model (int): Flag to indicate whether to use pre-trained model for training. + If set to `1`, then the pretrained model with the corresponding number + of layers is loaded and used for training. Only the top FC layer are + reinitialized with random weights. Otherwise, the network is trained from scratch. Default value: 0 + checkpoint_frequency (int): Period to store model parameters (in number of epochs). Default value: 1 + num_layers (int): Number of layers for the network. For data with large image size (for example, 224x224 - like ImageNet), + we suggest selecting the number of layers from the set [18, 34, 50, 101, 152, 200]. For data with small + image size (for example, 28x28 - like CFAR), we suggest selecting the number of layers from the + set [20, 32, 44, 56, 110]. The number of layers in each set is based on the ResNet paper. + For transfer learning, the number of layers defines the architecture of base network and hence + can only be selected from the set [18, 34, 50, 101, 152, 200]. Default value: 152 + resize (int): Resize the image before using it for training. The images are resized so that the shortest side is of this + parameter. If the parameter is not set, then the training data is used as such without resizing. + Note: This option is available only for inputs specified as application/x-image content-type in + training and validation channels. + epochs (int): Number of training epochs. Default value: 30 + learning_rate (float): Initial learning rate. Float. Range in [0, 1]. Default value: 0.1 + lr_scheduler_factor (flaot): The ratio to reduce learning rate used in conjunction with the `lr_scheduler_step` parameter, + defined as `lr_new = lr_old * lr_scheduler_factor`. Valid values: Float. Range in [0, 1]. Default value: 0.1 + lr_scheduler_step (str): The epochs at which to reduce the learning rate. As explained in the ``lr_scheduler_factor`` parameter, the + learning rate is reduced by ``lr_scheduler_factor`` at these epochs. For example, if the value is set + to "10, 20", then the learning rate is reduced by ``lr_scheduler_factor`` after 10th epoch and again by + ``lr_scheduler_factor`` after 20th epoch. The epochs are delimited by ",". + optimizer (str): The optimizer types. For more details of the parameters for the optimizers, please refer to MXNet's API. + Valid values: One of sgd, adam, rmsprop, or nag. Default value: `sgd`. + momentum (float): The momentum for sgd and nag, ignored for other optimizers. Valid values: Float. Range in [0, 1]. Default value: 0 + weight_decay (float): The coefficient weight decay for sgd and nag, ignored for other optimizers. Range in [0, 1]. Default value: 0.0001 + beta_1 (float): The beta1 for adam, in other words, exponential decay rate for the first moment estimates. Range in [0, 1]. Default value: 0.9 + beta_2 (float): The beta2 for adam, in other words, exponential decay rate for the second moment estimates. Range in [0, 1]. Default value: 0.999 + eps (float): The epsilon for adam and rmsprop. It is usually set to a small value to avoid division by 0. Range in [0, 1]. Default value: 1e-8 + gamma (float): The gamma for rmsprop. A decay factor of moving average of the squared gradient. Range in [0, 1]. Default value: 0.9 + mini_batch_size (int): The batch size for training. In a single-machine multi-GPU setting, each GPU handles mini_batch_size/num_gpu + training samples. For the multi-machine training in dist_sync mode, the actual batch size is mini_batch_size*number + of machines. See MXNet docs for more details. Default value: 32 + image_shape (str): The input image dimensions, which is the same size as the input layer of the network. + The format is defined as 'num_channels, height, width'. The image dimension can take on any value as the + network can handle varied dimensions of the input. However, there may be memory constraints if a larger image + dimension is used. Typical image dimensions for image classification are '3, 224, 224'. This is similar to the ImageNet dataset. + Default value: ‘3, 224, 224’ + augmentation_type: (str): Data augmentation type. The input images can be augmented in multiple ways as specified below. + 'crop' - Randomly crop the image and flip the image horizontally + 'crop_color' - In addition to ‘crop’, three random values in the range [-36, 36], [-50, 50], and [-50, 50] + are added to the corresponding Hue-Saturation-Lightness channels respectively + 'crop_color_transform': In addition to crop_color, random transformations, including rotation, + shear, and aspect ratio variations are applied to the image. The maximum angle of rotation + is 10 degrees, the maximum shear ratio is 0.1, and the maximum aspect changing ratio is 0.25. + top_k (int): Report the top-k accuracy during training. This parameter has to be greater than 1, + since the top-1 training accuracy is the same as the regular training accuracy that has already been reported. + kv_store (str): Weight update synchronization mode during distributed training. The weight updates can be updated either synchronously + or asynchronously across machines. Synchronous updates typically provide better accuracy than asynchronous + updates but can be slower. See distributed training in MXNet for more details. This parameter is not applicable + to single machine training. + 'dist_sync' - The gradients are synchronized after every batch with all the workers. With dist_sync, + batch-size now means the batch size used on each machine. So if there are n machines and we use + batch size b, then dist_sync behaves like local with batch size n*b + 'dist_async'- Performs asynchronous updates. The weights are updated whenever gradients are received from any + machine and the weight updates are atomic. However, the order is not guaranteed. + **kwargs: base class keyword argument values. + """ + super(ImageClassification, self).__init__(role, train_instance_count, train_instance_type, **kwargs) + self.num_classes = num_classes + self.num_training_samples = num_training_samples + self.resize = resize + self.lr_scheduler_step = lr_scheduler_step + self.use_pretrained_model = use_pretrained_model + self.checkpoint_frequency = checkpoint_frequency + self.num_layers = num_layers + self.epochs = epochs + self.learning_rate = learning_rate + self.lr_schedule_factor = lr_schedule_factor + self.optimizetr = optimizer + self.momentum = momentum + self.weight_decay = weight_decay + self.beta_1 = beta_1 + self.beta_2 = beta_2 + self.eps = eps + self.gamma = gamma + self.mini_batch_size = mini_batch_size + self.image_shape = image_shape + self.augmentation_type = augmentation_type + self.top_k = top_k + self.kv_store = kv_store + + def create_model(self): + """Return a :class:`~sagemaker.amazon.image_classification.ImageClassification` referencing the latest + s3 model data produced by this Estimator.""" + return ImageClassificationModel(self.model_data, self.role, self.sagemaker_session) + + def hyperparameters(self): + """Return the SageMaker hyperparameters for training this ImageClassification Estimator""" + hp = dict(force_dense='True') # Not sure what this is. + hp.update(super(ImageClassification, self).hyperparameters()) + return hp + + +class ImageClassificationPredictor(RealTimePredictor): + """Assigns input vectors to their closest cluster in a ImageClassification model. + + The implementation of :meth:`~sagemaker.predictor.RealTimePredictor.predict` in this + `RealTimePredictor` requires a `x-image` as input. + + ``predict()`` returns """ + + def __init__(self, endpoint, sagemaker_session=None): + super(ImageClassifcationPredictor, self).__init__(endpoint, sagemaker_session, serializer=numpy_to_record_serializer(), + deserializer=record_deserializer(), content_type = 'application/x-image') + + +class ImageClassificationModel(Model): + """Reference KMeans s3 model data. Calling :meth:`~sagemaker.model.Model.deploy` creates an Endpoint and return + a Predictor to performs k-means cluster assignment.""" + + def __init__(self, model_data, role, sagemaker_session=None): + sagemaker_session = sagemaker_session or Session() + image = registry(sagemaker_session.boto_session.region_name, algorithm = 'image_classification') + \ + "/" + ImageClassification.repo + super(ImageClassificationModel, self).__init__(model_data, image, role, predictor_cls=ImageClassificationPredictor, + sagemaker_session=sagemaker_session) diff --git a/tests/integ/test_image_classification.py b/tests/integ/test_image_classification.py new file mode 100644 index 0000000000..b2be924685 --- /dev/null +++ b/tests/integ/test_image_classification.py @@ -0,0 +1,72 @@ +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +import gzip +import pickle +import sys + +import boto3 +import os + +import sagemaker +from sagemaker import ImageClassification, ImageClassificationModel +from sagemaker.utils import name_from_base +from tests.integ import DATA_DIR, REGION +from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name +import urllib + +def download(url): + filename = url.split("/")[-1] + if not os.path.exists(filename): + urllib.request.urlretrieve(url, filename) + + +def upload_to_s3(channel, file, bucket): + s3 = boto3.resource('s3') + data = open(file, "rb") + key = channel + '/' + file + s3.Bucket(bucket).put_object(Key=key, Body=data) + +def test_image_classification(): + + with timeout(minutes=15): + sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION)) + + # caltech-256 + download('http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec') + upload_to_s3('train', 'caltech-256-60-train.rec', sagemaker_session.default_bucket()) + download('http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec') + upload_to_s3('validation', 'caltech-256-60-val.rec', sagemaker_session.default_bucket()) + ic = ImageClassification(role='SageMakerRole', train_instance_count=1, + train_instance_type='ml.c4.xlarge', data_location = 's3://' + sagemaker_session.default_bucket(), + num_classes=257, num_training_samples=15420, epochs = 1, image_shape= '3,32,32', + sagemaker_session=sagemaker_session, base_job_name='test-ic') + + ic.epochs = 1 + records = [] + records.append(ic.s3_record_set( 'train', channel = 'train')) + records.append(ic.s3_record_set( 'validation', channel = 'validation')) + import pdb + pdb.set_trace() + ic.fit(records) + """ + endpoint_name = name_from_base('ic') + with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session, minutes=20): + model = ImageClassificationModel(ic.model_data, role='SageMakerRole', sagemaker_session=sagemaker_session) + predictor = model.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name) + result = predictor.predict(train_set[0][:10]) + + assert len(result) == 10 + for record in result: + assert record.label["closest_cluster"] is not None + assert record.label["distance_to_cluster"] is not None + """ From 8b96f693f6f6dd1aeacd62f47548f9ef5aa56891 Mon Sep 17 00:00:00 2001 From: Ragav Venkatesan Date: Tue, 16 Jan 2018 19:20:32 -0800 Subject: [PATCH 2/5] image classification api --- .gitignore | 1 + src/sagemaker/__init__.py | 4 +- src/sagemaker/amazon/amazon_estimator.py | 78 ++++++++++++++++++++---- src/sagemaker/amazon/validation.py | 5 ++ tests/unit/test_amazon_estimator.py | 29 +++++++++ 5 files changed, 103 insertions(+), 14 deletions(-) diff --git a/.gitignore b/.gitignore index 3ee5780429..cb6d0d2664 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,4 @@ examples/tensorflow/distributed_mnist/data doc/_build **/.DS_Store venv/ +*.rec \ No newline at end of file diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index d5901c086d..15e8600742 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -16,6 +16,7 @@ from sagemaker.amazon.kmeans import KMeans, KMeansModel, KMeansPredictor from sagemaker.amazon.pca import PCA, PCAModel, PCAPredictor from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerModel, LinearLearnerPredictor +from sagemaker.amazon.image_classification import ImageClassification, ImageClassificationModel, ImageClassificationPredictor from sagemaker.model import Model from sagemaker.predictor import RealTimePredictor @@ -27,5 +28,6 @@ __all__ = [estimator, KMeans, KMeansModel, KMeansPredictor, PCA, PCAModel, PCAPredictor, LinearLearner, - LinearLearnerModel, LinearLearnerPredictor, Model, RealTimePredictor, Session, + LinearLearnerModel, LinearLearnerPredictor, Model, RealTimePredictor, Session, + ImageClassification, ImageClassificationModel, ImageClassificationPredictor, container_def, s3_input, production_variant, get_execution_role] diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 6da58aa165..67ee7fd0b6 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -23,7 +23,6 @@ logger = logging.getLogger(__name__) - class AmazonAlgorithmEstimatorBase(EstimatorBase): """Base class for Amazon first-party Estimator implementations. This class isn't intended to be instantiated directly.""" @@ -128,10 +127,53 @@ def record_set(self, train, labels=None, channel="train"): logger.debug("Created manifest file {}".format(manifest_s3_file)) return RecordSet(manifest_s3_file, num_records=train.shape[0], feature_dim=train.shape[1], channel=channel) +class AmazonS3AlgorithmEstimatorBase(AmazonAlgorithmEstimatorBase): + """Base class for Amazon first-party Estimator implementations. This class ins't + intended to be instantiated directly. This is difference from the base class + because this class handles S3 data""" + + def fit(self, records, mini_batch_size=None, distribution = 'ShardedByS3Key', **kwargs): + """Fit this Estimator on serialized Record objects, stored in S3. + + ``records`` should be a list of instances of :class:`~RecordSet`. This defines a collection of + s3 data files to train this ``Estimator`` on. + + More information on the Amazon Record format is available at: + https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html + + See :meth:`~AmazonS3AlgorithmEstimatorBase.s3_record_set` to construct a ``RecordSet`` object + from :class:`~numpy.ndarray` arrays. + + Args: + records (list): This is a list of :class:`~RecordSet` items The list of records to train + this ``Estimator`` will depend on each algorithm and type of input data. + mini_batch_size (int or None): The size of each mini-batch to use when training. If None, a + default value will be used. + """ + default_mini_batch_size = self.MAX_DEFAULT_BATCH_SIZE + self.mini_batch_size = mini_batch_size or default_mini_batch_size + #self.feature_dim = records.feature_dim + data = {} + for record in records: + data = {record.channel: s3_input(record.s3_data, distribution=distribution, + s3_data_type=record.s3_data_type)} + super(AmazonAlgorithmEstimatorBase, self).fit(data, **kwargs) + + def s3_record_set(self, s3_loc, channel="train" ): + """Build a :class:`~RecordSet` from a S3 location with data in it. + + Args: + s3_loc (str): A s3 bucket where data is located + channel (str): The SageMaker TrainingJob channel this RecordSet should be assigned to. + + Returns: + RecordSet: A RecordSet referencing the encoded, uploading training and label data. + """ + return RecordSet(s3_loc, channel=channel) class RecordSet(object): - def __init__(self, s3_data, num_records, feature_dim, s3_data_type='ManifestFile', channel='train'): + def __init__(self, s3_data, num_records = None, feature_dim = None, s3_data_type='ManifestFile', channel='train'): """A collection of Amazon :class:~`Record` objects serialized and stored in S3. Args: @@ -166,7 +208,6 @@ def _build_shards(num_shards, array): shards.append(array[(num_shards - 1) * shard_size:]) return shards - def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels=None): """Upload the training ``array`` and ``labels`` arrays to ``num_shards`` s3 objects, stored in "s3://``bucket``/``key_prefix``/".""" @@ -202,13 +243,24 @@ def upload_numpy_to_s3_shards(num_shards, s3, bucket, key_prefix, array, labels= finally: raise ex - -def registry(region_name): - """Return docker registry for the given AWS region""" - account_id = { - "us-east-1": "382416733822", - "us-east-2": "404615174143", - "us-west-2": "174872318107", - "eu-west-1": "438346466558" - }[region_name] - return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name) +def registry(region_name, algorithm = None): + """Return docker registry for the given AWS region + + Args: + algorithm (str): Provide the algorithm to get the docker back""" + if algorithm is None: + account_id = { + "us-east-1": "382416733822", + "us-east-2": "404615174143", + "us-west-2": "174872318107", + "eu-west-1": "438346466558" + }[region_name] + return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name) + elif algorithm in ['image_classification']: + account_id = { + "us-east-1": "811284229777", + "us-east-2": "825641698319", + "us-west-2": "433757028032", + "eu-west-1": "685385470294" + }[region_name] + return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name) \ No newline at end of file diff --git a/src/sagemaker/amazon/validation.py b/src/sagemaker/amazon/validation.py index ff3259be8f..93fa960d30 100644 --- a/src/sagemaker/amazon/validation.py +++ b/src/sagemaker/amazon/validation.py @@ -30,6 +30,10 @@ def validate(value): return value < maximum return validate +def le(maximum): + def validate(value): + return value <= maximum + return validate def isin(*expected): def validate(value): @@ -45,4 +49,5 @@ def validate(value): isint = istype(int) isbool = istype(bool) +isstr = istype(str) isnumber = istype(numbers.Number) # noqa diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index a9eb15886e..e562be86d5 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -16,6 +16,7 @@ # Use PCA as a test implementation of AmazonAlgorithmEstimator from sagemaker.amazon.pca import PCA +from sagemaker.amazon.image_classification import ImageClassification from sagemaker.amazon.amazon_estimator import upload_numpy_to_s3_shards, _build_shards, registry @@ -63,6 +64,10 @@ def test_init(sagemaker_session): pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS) assert pca.num_components == 55 +def test_s3_init(sagemaker_session): + ic = ImageClassification(epochs = 12, num_classes = 2, num_training_samples = 2, + sagemaker_session=sagemaker_session, **COMMON_ARGS) + assert ic.epochs == 12 def test_init_all_pca_hyperparameters(sagemaker_session): pca = PCA(num_components=55, algorithm_mode='randomized', @@ -72,6 +77,14 @@ def test_init_all_pca_hyperparameters(sagemaker_session): assert pca.algorithm_mode == 'randomized' assert pca.extra_components == 33 +def test_init_all_ic_hyperparameters(sagemaker_session): + ic = ImageClassification(data_location='s3://some-bucket/some-key/', + num_classes=257, num_training_samples=15420, epochs = 1, + image_shape= '3,32,32', sagemaker_session=sagemaker_session, + **COMMON_ARGS) + assert ic.num_classes == 257 + assert ic.num_training_samples == 15420 + assert ic.image_shape == '3,32,32' def test_init_estimator_args(sagemaker_session): pca = PCA(num_components=1, train_max_run=1234, sagemaker_session=sagemaker_session, @@ -82,6 +95,16 @@ def test_init_estimator_args(sagemaker_session): assert pca.train_max_run == 1234 assert pca.data_location == 's3://some-bucket/some-key/' +def test_init_s3estimator_args(sagemaker_session): + ic = ImageClassification(data_location='s3://some-bucket/some-key/', + num_classes=257, num_training_samples=15420, epochs = 1, + image_shape= '3,32,32', sagemaker_session=sagemaker_session, + **COMMON_ARGS) + assert ic.train_instance_type == COMMON_ARGS['train_instance_type'] + assert ic.train_instance_count == COMMON_ARGS['train_instance_count'] + assert ic.role == COMMON_ARGS['role'] + assert ic.data_location == 's3://some-bucket/some-key/' + def test_data_location_validation(sagemaker_session): pca = PCA(num_components=2, sagemaker_session=sagemaker_session, **COMMON_ARGS) @@ -99,6 +122,12 @@ def test_pca_hyperparameters(sagemaker_session): subtract_mean='True', algorithm_mode='randomized') +def test_ic_hyperparameters(sagemaker_session): + ic = ImageClassification(data_location = 's3://some-bucket/some-key/', + num_classes=257, num_training_samples=15420, epochs = 1, + image_shape= '3,32,32', sagemaker_session=sagemaker_session, + **COMMON_ARGS) + assert isinstance(ic.hyperparameters(),dict) def test_image(sagemaker_session): pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS) From 7167aecd10dd6b85d7290f312fd24f6fc306383a Mon Sep 17 00:00:00 2001 From: Ragav Venkatesan Date: Tue, 16 Jan 2018 19:25:24 -0800 Subject: [PATCH 3/5] sync --- src/sagemaker/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index 9ab20f045a..2dc35de963 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -30,6 +30,8 @@ __all__ = [estimator, KMeans, KMeansModel, KMeansPredictor, PCA, PCAModel, PCAPredictor, LinearLearner, - LinearLearnerModel, LinearLearnerPredictor, Model, RealTimePredictor, Session, + LinearLearnerModel, LinearLearnerPredictor, + FactorizationMachines, FactorizationMachinesModel, FactorizationMachinesPredictor, + Model, RealTimePredictor, Session, ImageClassification, ImageClassificationModel, ImageClassificationPredictor, container_def, s3_input, production_variant, get_execution_role ] From 3d985e75476bdfd7257a59bbdabde542e597f528 Mon Sep 17 00:00:00 2001 From: Ragav Venkatesan Date: Tue, 16 Jan 2018 20:35:25 -0800 Subject: [PATCH 4/5] estimator is done. waiting on tests. --- src/sagemaker/amazon/amazon_estimator.py | 10 +++++----- tests/integ/test_image_classification.py | 10 ++++------ 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 252a754bf9..12e7d99a8d 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -147,13 +147,13 @@ def fit(self, records, mini_batch_size=None, distribution = 'ShardedByS3Key', ** mini_batch_size (int or None): The size of each mini-batch to use when training. If None, a default value will be used. """ - default_mini_batch_size = self.MAX_DEFAULT_BATCH_SIZE + default_mini_batch_size = 32 self.mini_batch_size = mini_batch_size or default_mini_batch_size - #self.feature_dim = records.feature_dim + #self.feature_dim = records.feature_dim data = {} for record in records: - data = {record.channel: s3_input(record.s3_data, distribution=distribution, - s3_data_type=record.s3_data_type)} + data[record.channel] = s3_input(record.s3_data, distribution=distribution, + s3_data_type=record.s3_data_type) super(AmazonAlgorithmEstimatorBase, self).fit(data, **kwargs) def s3_record_set(self, s3_loc, channel="train" ): @@ -166,7 +166,7 @@ def s3_record_set(self, s3_loc, channel="train" ): Returns: RecordSet: A RecordSet referencing the encoded, uploading training and label data. """ - return RecordSet(s3_loc, channel=channel) + return RecordSet(self.data_location + '/' + s3_loc, channel=channel) class RecordSet(object): diff --git a/tests/integ/test_image_classification.py b/tests/integ/test_image_classification.py index b2be924685..750638268f 100644 --- a/tests/integ/test_image_classification.py +++ b/tests/integ/test_image_classification.py @@ -45,18 +45,16 @@ def test_image_classification(): download('http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec') upload_to_s3('train', 'caltech-256-60-train.rec', sagemaker_session.default_bucket()) download('http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec') - upload_to_s3('validation', 'caltech-256-60-val.rec', sagemaker_session.default_bucket()) + upload_to_s3('validation', 'caltech-256-60-val.rec', sagemaker_session.default_bucket()) ic = ImageClassification(role='SageMakerRole', train_instance_count=1, - train_instance_type='ml.c4.xlarge', data_location = 's3://' + sagemaker_session.default_bucket(), + train_instance_type='ml.p3.2xlarge', data_location = 's3://' + sagemaker_session.default_bucket(), num_classes=257, num_training_samples=15420, epochs = 1, image_shape= '3,32,32', sagemaker_session=sagemaker_session, base_job_name='test-ic') ic.epochs = 1 records = [] - records.append(ic.s3_record_set( 'train', channel = 'train')) - records.append(ic.s3_record_set( 'validation', channel = 'validation')) - import pdb - pdb.set_trace() + records.append(ic.s3_record_set( 'training', channel = 'train')) + records.append(ic.s3_record_set( 'validation', channel = 'validation')) ic.fit(records) """ endpoint_name = name_from_base('ic') From 24353e2835739bcc28c87e8715887e6d23ff38cc Mon Sep 17 00:00:00 2001 From: Ragav Venkatesan Date: Wed, 24 Jan 2018 12:40:23 -0800 Subject: [PATCH 5/5] formatting for flake --- src/sagemaker/__init__.py | 5 ++-- src/sagemaker/amazon/amazon_estimator.py | 17 +++++++------ tests/unit/test_amazon_estimator.py | 32 ++++++++++++++---------- 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/sagemaker/__init__.py b/src/sagemaker/__init__.py index 2dc35de963..d60e37a061 100644 --- a/src/sagemaker/__init__.py +++ b/src/sagemaker/__init__.py @@ -16,7 +16,8 @@ from sagemaker.amazon.kmeans import KMeans, KMeansModel, KMeansPredictor from sagemaker.amazon.pca import PCA, PCAModel, PCAPredictor from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerModel, LinearLearnerPredictor -from sagemaker.amazon.image_classification import ImageClassification, ImageClassificationModel, ImageClassificationPredictor +from sagemaker.amazon.image_classification import ImageClassification, ImageClassificationModel +from sagemaker.amazon.image_classification import ImageClassificationPredictor from sagemaker.amazon.factorization_machines import FactorizationMachines, FactorizationMachinesModel from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor @@ -34,4 +35,4 @@ FactorizationMachines, FactorizationMachinesModel, FactorizationMachinesPredictor, Model, RealTimePredictor, Session, ImageClassification, ImageClassificationModel, ImageClassificationPredictor, - container_def, s3_input, production_variant, get_execution_role ] + container_def, s3_input, production_variant, get_execution_role] \ No newline at end of file diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py index 12e7d99a8d..ef9b2c03c3 100644 --- a/src/sagemaker/amazon/amazon_estimator.py +++ b/src/sagemaker/amazon/amazon_estimator.py @@ -124,12 +124,13 @@ def record_set(self, train, labels=None, channel="train"): logger.debug("Created manifest file {}".format(manifest_s3_file)) return RecordSet(manifest_s3_file, num_records=train.shape[0], feature_dim=train.shape[1], channel=channel) + class AmazonS3AlgorithmEstimatorBase(AmazonAlgorithmEstimatorBase): - """Base class for Amazon first-party Estimator implementations. This class ins't - intended to be instantiated directly. This is difference from the base class + """Base class for Amazon first-party Estimator implementations. This class isn't + intended to be instantiated directly. This is difference from the base class because this class handles S3 data""" - def fit(self, records, mini_batch_size=None, distribution = 'ShardedByS3Key', **kwargs): + def fit(self, records, mini_batch_size=None, distribution='ShardedByS3Key', **kwargs): """Fit this Estimator on serialized Record objects, stored in S3. ``records`` should be a list of instances of :class:`~RecordSet`. This defines a collection of @@ -142,21 +143,20 @@ def fit(self, records, mini_batch_size=None, distribution = 'ShardedByS3Key', ** from :class:`~numpy.ndarray` arrays. Args: - records (list): This is a list of :class:`~RecordSet` items The list of records to train + records (list): This is a list of :class:`~RecordSet` items The list of records to train this ``Estimator`` will depend on each algorithm and type of input data. mini_batch_size (int or None): The size of each mini-batch to use when training. If None, a - default value will be used. + default value will be used. """ default_mini_batch_size = 32 self.mini_batch_size = mini_batch_size or default_mini_batch_size - #self.feature_dim = records.feature_dim data = {} for record in records: data[record.channel] = s3_input(record.s3_data, distribution=distribution, - s3_data_type=record.s3_data_type) + s3_data_type=record.s3_data_type) super(AmazonAlgorithmEstimatorBase, self).fit(data, **kwargs) - def s3_record_set(self, s3_loc, channel="train" ): + def s3_record_set(self, s3_loc, channel="train"): """Build a :class:`~RecordSet` from a S3 location with data in it. Args: @@ -168,6 +168,7 @@ def s3_record_set(self, s3_loc, channel="train" ): """ return RecordSet(self.data_location + '/' + s3_loc, channel=channel) +# Re-write a new recordset class for s3 objects. class RecordSet(object): def __init__(self, s3_data, num_records = None, feature_dim = None, s3_data_type='ManifestFile', channel='train'): diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index e562be86d5..0dac3ce678 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -19,7 +19,6 @@ from sagemaker.amazon.image_classification import ImageClassification from sagemaker.amazon.amazon_estimator import upload_numpy_to_s3_shards, _build_shards, registry - COMMON_ARGS = {'role': 'myrole', 'train_instance_count': 1, 'train_instance_type': 'ml.c4.xlarge'} REGION = "us-west-2" @@ -64,11 +63,13 @@ def test_init(sagemaker_session): pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS) assert pca.num_components == 55 + def test_s3_init(sagemaker_session): - ic = ImageClassification(epochs = 12, num_classes = 2, num_training_samples = 2, - sagemaker_session=sagemaker_session, **COMMON_ARGS) + ic = ImageClassification(epochs=12, num_classes=2, num_training_samples=2, + sagemaker_session=sagemaker_session, **COMMON_ARGS) assert ic.epochs == 12 + def test_init_all_pca_hyperparameters(sagemaker_session): pca = PCA(num_components=55, algorithm_mode='randomized', subtract_mean=True, extra_components=33, sagemaker_session=sagemaker_session, @@ -77,15 +78,17 @@ def test_init_all_pca_hyperparameters(sagemaker_session): assert pca.algorithm_mode == 'randomized' assert pca.extra_components == 33 + def test_init_all_ic_hyperparameters(sagemaker_session): ic = ImageClassification(data_location='s3://some-bucket/some-key/', - num_classes=257, num_training_samples=15420, epochs = 1, - image_shape= '3,32,32', sagemaker_session=sagemaker_session, + num_classes=257, num_training_samples=15420, epochs=1, + image_shape='3,32,32', sagemaker_session=sagemaker_session, **COMMON_ARGS) assert ic.num_classes == 257 assert ic.num_training_samples == 15420 assert ic.image_shape == '3,32,32' + def test_init_estimator_args(sagemaker_session): pca = PCA(num_components=1, train_max_run=1234, sagemaker_session=sagemaker_session, data_location='s3://some-bucket/some-key/', **COMMON_ARGS) @@ -95,11 +98,12 @@ def test_init_estimator_args(sagemaker_session): assert pca.train_max_run == 1234 assert pca.data_location == 's3://some-bucket/some-key/' + def test_init_s3estimator_args(sagemaker_session): ic = ImageClassification(data_location='s3://some-bucket/some-key/', - num_classes=257, num_training_samples=15420, epochs = 1, - image_shape= '3,32,32', sagemaker_session=sagemaker_session, - **COMMON_ARGS) + num_classes=257, num_training_samples=15420, epochs=1, + image_shape='3,32,32', sagemaker_session=sagemaker_session, + **COMMON_ARGS) assert ic.train_instance_type == COMMON_ARGS['train_instance_type'] assert ic.train_instance_count == COMMON_ARGS['train_instance_count'] assert ic.role == COMMON_ARGS['role'] @@ -122,12 +126,14 @@ def test_pca_hyperparameters(sagemaker_session): subtract_mean='True', algorithm_mode='randomized') + def test_ic_hyperparameters(sagemaker_session): - ic = ImageClassification(data_location = 's3://some-bucket/some-key/', - num_classes=257, num_training_samples=15420, epochs = 1, - image_shape= '3,32,32', sagemaker_session=sagemaker_session, - **COMMON_ARGS) - assert isinstance(ic.hyperparameters(),dict) + ic = ImageClassification(data_location='s3://some-bucket/some-key/', + num_classes=257, num_training_samples=15420, epochs=1, + image_shape='3,32,32', sagemaker_session=sagemaker_session, + **COMMON_ARGS) + assert isinstance(ic.hyperparameters(), dict) + def test_image(sagemaker_session): pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS)