# Object detection using managed spot training

This notebook shows how to use [Amazon SageMaker managed spot training](https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html) to run training jobs at potentially lower cost. Managed spot training uses [Amazon EC2 Spot instances](https://aws.amazon.com/ec2/spot/) and manages the Spot interruptions on your behalf.

To highlight the differences between on-demand and Spot instances, this notebook is the same as [Amazon SageMaker Object Detection using the RecordIO format](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/introduction_to_amazon_algorithms/object_detection_pascalvoc_coco/object_detection_recordio_format.ipynb), but has been updated to use managed spot training. For a full description of the ML use case and how it is being solved, see the original notebook.

## Setup

See [Amazon SageMaker Object Detection using the RecordIO format](https://github.com/awslabs/amazon-sagemaker-examples/blob/master/introduction_to_amazon_algorithms/object_detection_pascalvoc_coco/object_detection_recordio_format.ipynb) for a description of the code.

### Prerequisites

This notebook has been tested with:
* SageMaker Python SDK 1.72.1
* Python 3.6
* Kernel: conda_mxnet_p36

In [None]:
import sagemaker
from sagemaker import get_execution_role

role = get_execution_role()
sess = sagemaker.Session()
bucket = sess.default_bucket() 
prefix = 'DEMO-ObjectDetection'
training_image = sagemaker.image_uris.retrieve(region=sess.boto_region_name, framework='object-detection')


### Download and prepare data
This notebook downloads and uses the Pascal VOC dataset, which has the following database usage rights:
> The VOC data includes images obtained from the Flickr website. Use of these images must respect the corresponding terms of use: 
> * Flickr terms of use (https://www.flickr.com/help/terms)

In [None]:
# Download the dataset
!wget -P /tmp http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
!wget -P /tmp http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
!wget -P /tmp http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar

# Extract the data.
!tar -xf /tmp/VOCtrainval_11-May-2012.tar && rm /tmp/VOCtrainval_11-May-2012.tar
!tar -xf /tmp/VOCtrainval_06-Nov-2007.tar && rm /tmp/VOCtrainval_06-Nov-2007.tar
!tar -xf /tmp/VOCtest_06-Nov-2007.tar && rm /tmp/VOCtest_06-Nov-2007.tar

# Convert data into RecordIO
!python tools/prepare_dataset.py --dataset pascal --year 2007,2012 --set trainval --target VOCdevkit/train.lst
!rm -rf VOCdevkit/VOC2012
!python tools/prepare_dataset.py --dataset pascal --year 2007 --set test --target VOCdevkit/val.lst --no-shuffle
!rm -rf VOCdevkit/VOC2007

### Upload data to Amazon Simple Storage Service (Amazon S3)

In [None]:
# Upload the RecordIO files to train and validation channels
train_channel = prefix + '/train'
validation_channel = prefix + '/validation'

sess.upload_data(path='VOCdevkit/train.rec', bucket=bucket, key_prefix=train_channel)
sess.upload_data(path='VOCdevkit/val.rec', bucket=bucket, key_prefix=validation_channel)

s3_train_data = 's3://{}/{}'.format(bucket, train_channel)
s3_validation_data = 's3://{}/{}'.format(bucket, validation_channel)

## Managed spot training

Managed spot training is controlled by two arguments to the `sagemaker.estimator.Estimator` constructor (named `train_use_spot_instances` and `train_max_wait` in version 1.x of the SageMaker Python SDK):

* `use_spot_instances`: Set to `True` to use Spot instances for training jobs.
* `max_wait`: The maximum time, in seconds, to wait for the training job to complete. This covers both the time spent waiting for Spot instances to become available and the training time itself. Be aware that some Spot instance types take longer to acquire. You are charged only for the compute time used once Spot instances have been acquired, not for the time spent waiting for them.

Note that `max_wait` can be set only if `use_spot_instances` is `True`, and it **must** be greater than or equal to `max_run`, the maximum training time in seconds (`train_max_run` in SDK 1.x).
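
The two constraints above can be checked locally before a job is submitted. The following is a minimal sketch, not part of the SageMaker SDK: `resolve_spot_settings` is a hypothetical helper, and defaulting the wait limit to the run limit when none is given is an assumption made here for illustration.

```python
def resolve_spot_settings(use_spot_instances, max_run, max_wait=None):
    """Validate Spot settings and return them as a dict (hypothetical helper)."""
    if not use_spot_instances:
        # The wait limit is only meaningful for managed spot training.
        if max_wait is not None:
            raise ValueError("max_wait can be set only when Spot instances are used")
        return {"use_spot_instances": False, "max_run": max_run, "max_wait": None}

    if max_wait is None:
        # Assumption of this sketch: fall back to the smallest allowed value.
        max_wait = max_run
    if max_wait < max_run:
        raise ValueError("max_wait must be greater than or equal to max_run")
    return {"use_spot_instances": True, "max_run": max_run, "max_wait": max_wait}


settings = resolve_spot_settings(use_spot_instances=True, max_run=3600, max_wait=7200)
print(settings)
```

A dictionary produced this way could be unpacked into the estimator constructor as keyword arguments, assuming the constructor accepts these argument names in the SDK version you use.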

Set `train_use_spot_instances` to `False` in the following cell to run the same job on on-demand instances and compare the results.

In [None]:
train_use_spot_instances = True
train_max_run = 3600  # Maximum training time, in seconds
train_max_wait = 3600 if train_use_spot_instances else None  # Must be >= train_max_run

### Training

Train the object detector by creating a `sagemaker.estimator.Estimator` object and launching the training job.

In [None]:
s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)
od_model = sagemaker.estimator.Estimator(training_image,
                                         role,
                                         instance_count=1,
                                         instance_type='ml.p3.2xlarge',
                                         volume_size=50,
                                         input_mode='File',
                                         output_path=s3_output_location,
                                         sagemaker_session=sess,
                                         use_spot_instances=train_use_spot_instances,
                                         max_run=train_max_run,
                                         max_wait=train_max_wait)

od_model.set_hyperparameters(base_network='resnet-50',
                             use_pretrained_model=1,
                             num_classes=20,
                             mini_batch_size=32,
                             epochs=1,
                             learning_rate=0.001,
                             lr_scheduler_step='3,6',
                             lr_scheduler_factor=0.1,
                             optimizer='sgd',
                             momentum=0.9,
                             weight_decay=0.0005,
                             overlap_threshold=0.5,
                             nms_threshold=0.45,
                             image_shape=300,
                             label_width=350,
                             num_training_samples=16551)

train_data = sagemaker.inputs.TrainingInput(s3_train_data, distribution='FullyReplicated', 
                        content_type='application/x-recordio', s3_data_type='S3Prefix')
validation_data = sagemaker.inputs.TrainingInput(s3_validation_data, distribution='FullyReplicated', 
                             content_type='application/x-recordio', s3_data_type='S3Prefix')
data_channels = {'train': train_data, 'validation': validation_data}

od_model.fit(inputs=data_channels, logs=True)

### Savings
At the end of the job output, two lines are printed:

* `Training seconds: X` : The actual compute time spent on the training job.
* `Billable seconds: Y` : The time you will be billed for after Spot discounting is applied.

When `train_use_spot_instances` is `True`, you should see a notable difference between training and billable seconds. This shows the cost savings when managed spot training is used, and is summarized in the final output:

* `Managed Spot Training savings: (1-Y/X)*100 %`
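
The savings formula above can also be evaluated directly. The following is a minimal sketch: it assumes the two values are read from the `TrainingTimeInSeconds` and `BillableTimeInSeconds` fields of a `DescribeTrainingJob` response. The dictionary below is made-up example data rather than the output of a real job, and both function names are hypothetical.

```python
def spot_savings_percent(training_seconds, billable_seconds):
    """Return the savings percentage, (1 - Y/X) * 100, for X training and Y billable seconds."""
    if training_seconds <= 0:
        raise ValueError("training_seconds must be positive")
    return (1 - billable_seconds / training_seconds) * 100


def savings_from_description(description):
    """Compute savings from a DescribeTrainingJob-style dictionary."""
    return spot_savings_percent(
        description["TrainingTimeInSeconds"],
        description["BillableTimeInSeconds"],
    )


# Made-up values for illustration only.
example_description = {"TrainingTimeInSeconds": 1000, "BillableTimeInSeconds": 250}
savings = savings_from_description(example_description)
print("Managed Spot Training savings: {:.1f}%".format(savings))  # prints 75.0%
```

For an on-demand job the billable and training seconds are equal, so the same function returns 0.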