# Managed Spot Training for XGBoost

This notebook shows how to use SageMaker Managed Spot infrastructure for XGBoost training. Below we show how Spot instances can be used with the 'algorithm mode' and 'script mode' training methods of the XGBoost container.

[Managed Spot Training](https://docs.aws.amazon.com/sagemaker/latest/dg/model-managed-spot-training.html) uses Amazon EC2 Spot instances instead of On-Demand instances to run training jobs. You can specify which training jobs use Spot instances, along with a stopping condition that specifies how long Amazon SageMaker waits for a job to run on Spot instances.

This notebook was tested in Amazon SageMaker Studio on an ml.t3.medium instance with the Python 3 (Data Science) kernel.

In this notebook we will perform XGBoost training as described [here](). See the original notebook for more details on the data. 

### Setup variables and define functions

In [2]:
!pip3 install -U sagemaker

Collecting sagemaker
  Downloading sagemaker-2.112.2.tar.gz (579 kB)
  Preparing metadata (setup.py) ... done
Collecting schema
  Using cached schema-0.7.5-py2.py3-none-any.whl (17 kB)
Building wheels for collected packages: sagemaker
  Building wheel for sagemaker (setup.py) ... done
  Created wheel for sagemaker: filename=sagemaker-2.112.2-py2.py3-none-any.whl size=796129 sha256=d4ce37abe0c458d35c13e4b445d2c80b86ffe5b6a244f0aeb19ebd3ae05ccb4a
  Stored in directory: /root/.cache/pip/wheels/c9/2a/d8/0db78f00aee63d4fddc2c64edcb1e761ef8e1a502137dcbaeb
Successfully built sagemaker
Installing collected packages: schema, sagemaker
  Attempting uninstall: sagemaker
    Found existing installation: sagemaker 2.107.0
    Uninstalling sagemaker-2.107.0:
      Successfully uninstalled sagemaker-2.107.0
Successfully installed sagemaker-2.112.2 s

In [3]:
%%time

import io
import os
import boto3
import sagemaker

role = sagemaker.get_execution_role()
region = boto3.Session().region_name

# S3 bucket and prefix for saving code, data, and model artifacts.
# Feel free to specify a different bucket or prefix here if you wish.
bucket = sagemaker.Session().default_bucket()
prefix = "sagemaker/DEMO-xgboost-builtin"

CPU times: user 951 ms, sys: 154 ms, total: 1.11 s
Wall time: 2.12 s


### Fetching the dataset

In [4]:
%%time
s3 = boto3.client("s3")
# Load the dataset
FILE_DATA = "abalone"
s3.download_file(
    "sagemaker-sample-files", "datasets/tabular/uci_abalone/abalone.libsvm", FILE_DATA
)
sagemaker.Session().upload_data(FILE_DATA, bucket=bucket, key_prefix=prefix + "/train")

CPU times: user 222 ms, sys: 28 ms, total: 250 ms
Wall time: 1.04 s


's3://sagemaker-us-west-2-240487350066/sagemaker/DEMO-xgboost-builtin/train/abalone'
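
The file uploaded above is in libsvm format, where each line holds a label followed by sparse `index:value` feature pairs. As a minimal sketch of that layout, the helper below parses a single record. The function name `parse_libsvm_line` and the sample line are hypothetical and chosen only for illustration; the sample is not copied from the dataset.

```python
def parse_libsvm_line(line):
    """Parse one libsvm record into (label, {feature_index: value})."""
    parts = line.split()
    label = float(parts[0])
    features = {}
    for item in parts[1:]:
        # Each feature is written as "<index>:<value>"; absent indices are implicitly zero.
        index, value = item.split(":", 1)
        features[int(index)] = float(value)
    return label, features


# Hypothetical record for illustration only.
sample = "15 1:1 2:0.455 3:0.365"
label, features = parse_libsvm_line(sample)
print(label, features)
```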

### Obtaining the latest XGBoost container
We obtain the new container by specifying the framework version (1.5-1). This version specifies the upstream XGBoost framework version (1.5) and an additional SageMaker version (1). If you have an existing XGBoost workflow based on the previous (1.0-1, 1.2-2 or 1.3-1) container, this would be the only change necessary to get the same workflow working with the new container.

In [5]:
container = sagemaker.image_uris.retrieve("xgboost", region, "1.5-1")
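
The version string passed to `image_uris.retrieve` combines the two parts described above. The sketch below splits such a string into its upstream XGBoost version and its SageMaker revision. `split_framework_version` is a hypothetical helper written for this illustration and is not part of the SageMaker SDK.

```python
def split_framework_version(version):
    """Split a string such as "1.5-1" into (upstream_version, sagemaker_revision)."""
    upstream, separator, revision = version.partition("-")
    if not separator or not revision.isdigit():
        raise ValueError(f"expected '<upstream>-<revision>', got {version!r}")
    return upstream, int(revision)


upstream_version, sagemaker_revision = split_framework_version("1.5-1")
print(upstream_version, sagemaker_revision)
```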

### Training the XGBoost model

After setting training parameters, we kick off training and poll for status until training is completed, which in this example takes a few minutes.

The cell below uses algorithm mode, so it constructs a generic `sagemaker.estimator.Estimator` from the XGBoost container image and needs no training script. To run your own training script in script mode, you would instead construct a `sagemaker.xgboost.estimator.XGBoost` estimator, which accepts several constructor arguments:

* __entry_point__: The path to the Python script SageMaker runs for training and prediction.
* __role__: Role ARN
* __hyperparameters__: A dictionary passed to the train function as hyperparameters.
* __instance_type__ *(optional)*: The type of SageMaker instances for training. __Note__: This particular mode does not currently support training on GPU instance types.
* __sagemaker_session__ *(optional)*: The session used to train on SageMaker.

In [6]:
hyperparameters = {
    "max_depth": "5",
    "eta": "0.2",
    "gamma": "4",
    "min_child_weight": "6",
    "subsample": "0.7",
    "objective": "reg:squarederror",
    "num_round": "50",
    "verbosity": "2",
}

instance_type = "ml.m5.2xlarge"
output_path = "s3://{}/{}/{}/output".format(bucket, prefix, "abalone-xgb")
content_type = "libsvm"

If Spot instances are used, the training job can be interrupted, causing it to take longer to start or finish. If a training job is interrupted, a checkpointed snapshot can be used to resume from a previously saved point and can save training time (and cost).

To enable checkpointing for Managed Spot Training using SageMaker XGBoost we need to configure three things: 

1. Enable the `use_spot_instances` constructor arg, a boolean that turns Managed Spot Training on.

2. Set the `max_wait` constructor arg, an int giving the total time in seconds you are willing to wait for the job to complete, including time spent waiting for Spot capacity. Some instance types are harder to get at Spot prices and you may have to wait longer. You are not charged for time spent waiting for Spot infrastructure to become available; you are charged only for the compute time used once Spot instances have been procured.

3. Set the `checkpoint_s3_uri` constructor arg, which tells SageMaker the S3 location where checkpoints are saved. While not strictly necessary, checkpointing is highly recommended for Managed Spot Training jobs: Spot instances can be interrupted on short notice, and resuming from the latest checkpoint preserves the progress made up to that checkpoint.

Feel free to toggle the `use_spot_instances` variable to see the effect of running the same job on regular (a.k.a. "On-Demand") infrastructure.
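
One way to quantify that effect is to compare a job's total training time with the time actually billed. The sketch below applies the savings formula given in the Managed Spot Training documentation, `(1 - billable / training) * 100`. The function name and the sample durations are hypothetical; in practice the two durations would be read from the description of the completed training job.

```python
def spot_savings_percent(training_seconds, billable_seconds):
    """Return the percentage saved when billable time is lower than training time."""
    if training_seconds <= 0:
        raise ValueError("training_seconds must be positive")
    return (1 - billable_seconds / training_seconds) * 100


# Hypothetical durations for illustration only.
savings = spot_savings_percent(training_seconds=200, billable_seconds=50)
print(f"Managed Spot Training savings: {savings:.1f}%")
```

An On-Demand job is billed for its full training time, so the same formula yields 0% for it.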

Note that `max_wait` can be set if and only if `use_spot_instances` is enabled, and it must be greater than or equal to `max_run`.
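
These constraints can be checked before any job is submitted. The following is a minimal sketch, using the SageMaker Python SDK v2 argument names that appear in the training cell (`use_spot_instances`, `max_run`, `max_wait`, `checkpoint_s3_uri`). The helper `build_spot_kwargs` and the example bucket name are hypothetical and not part of the SDK.

```python
def build_spot_kwargs(use_spot_instances, max_run, max_wait=None, checkpoint_s3_uri=None):
    """Validate Spot settings and return them as estimator keyword arguments."""
    if not use_spot_instances:
        if max_wait is not None:
            raise ValueError("max_wait can only be set when use_spot_instances is True")
        return {"use_spot_instances": False, "max_run": max_run}
    if max_wait is None:
        raise ValueError("max_wait is required when use_spot_instances is True")
    if max_wait < max_run:
        raise ValueError("max_wait must be greater than or equal to max_run")
    kwargs = {"use_spot_instances": True, "max_run": max_run, "max_wait": max_wait}
    if checkpoint_s3_uri is not None:
        # Checkpointing is optional but recommended for Spot jobs.
        kwargs["checkpoint_s3_uri"] = checkpoint_s3_uri
    return kwargs


# Hypothetical bucket name for illustration only.
spot_kwargs = build_spot_kwargs(
    use_spot_instances=True,
    max_run=3600,
    max_wait=7200,
    checkpoint_s3_uri="s3://example-bucket/checkpoints/example-job",
)
print(spot_kwargs)
```

The returned dictionary could then be unpacked into the estimator constructor with `**spot_kwargs`.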

In [7]:
import time
from sagemaker.inputs import TrainingInput

job_name = "DEMO-xgboost-spot-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
print("Training job", job_name)

# Managed Spot Training settings; set use_spot_instances to False for On-Demand.
use_spot_instances = True
max_run = 3600  # maximum training time, in seconds
max_wait = 7200 if use_spot_instances else None  # must be >= max_run
checkpoint_s3_uri = (
    "s3://{}/{}/checkpoints/{}".format(bucket, prefix, job_name) if use_spot_instances else None
)
print("Checkpoint path:", checkpoint_s3_uri)

estimator = sagemaker.estimator.Estimator(
    container,
    role,
    hyperparameters=hyperparameters,
    instance_count=1,
    instance_type=instance_type,
    volume_size=5,  # 5 GB
    output_path=output_path,
    sagemaker_session=sagemaker.Session(),
    use_spot_instances=use_spot_instances,
    max_run=max_run,
    max_wait=max_wait,
    checkpoint_s3_uri=checkpoint_s3_uri,
)
train_input = TrainingInput(
    s3_data="s3://{}/{}/{}".format(bucket, prefix, "train"), content_type=content_type
)
estimator.fit({"train": train_input}, job_name=job_name)

Training job DEMO-xgboost-spot-2022-10-12-15-36-09
2022-10-12 15:36:09 Starting - Starting the training job...ProfilerReport-1665588969: InProgress
...
2022-10-12 15:36:42 Starting - Preparing the instances for training......
2022-10-12 15:37:57 Downloading - Downloading input data...
2022-10-12 15:38:37 Training - Downloading the training image...
2022-10-12 15:39:07 Uploading - Uploading generated training model
[2022-10-12 15:39:00.514 ip-10-0-71-53.us-west-2.compute.internal:1 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2022-10-12:15:39:00:INFO] Imported framework sagemaker_xgboost_container.training
[2022-10-12:15:39:00:INFO] Failed to parse hyperparameter objective value reg:squarederror to Json.
Returning the value itself
[2022-10-12:15:39:00:INFO] No GPUs detected (normal if no gpus installed)
[2022-10-12:15:39:00:INFO] Running XGBoost Sagemaker in algorithm mode
[2022-10-12:15:39:00:INFO] files path: /opt/ml/