In [1]:
import boto3
import sagemaker
from sagemaker.inputs import TrainingInput
from sagemaker.tuner import IntegerParameter
from sagemaker.tuner import ContinuousParameter
from sagemaker.tuner import CategoricalParameter
from sagemaker.tuner import HyperparameterTuner
from sagemaker.xgboost.estimator import XGBoost

In [2]:
# Shared AWS handles for the rest of the notebook: a SageMaker session, its
# default S3 bucket, the notebook's IAM execution role, the AWS region, and a
# low-level SageMaker API client.
sess = sagemaker.Session()
bucket = sess.default_bucket()
role = sagemaker.get_execution_role()

# Build a single boto3 session and reuse it for both the region lookup and the
# client, instead of constructing a fresh boto3.Session() for each.
boto_session = boto3.Session()
region = boto_session.region_name

# NOTE(review): `sm` is not used in the cells shown here; presumably it is used
# later to poll the tuning job (tuner.fit runs with wait=False) — confirm.
sm = boto_session.client(service_name='sagemaker', region_name=region)

In [3]:
# S3 prefixes for each data split, all under the session's default bucket.
# Every prefix ends in '/' so that prefix matching selects only objects inside
# that "folder"; without it, '.../validation' would also match sibling keys
# such as '.../validation_old/...'.
train_data_s3_uri = 's3://{}/data/train/train/'.format(bucket)
validation_data_s3_uri = 's3://{}/data/train/validation/'.format(bucket)
test_data_s3_uri = 's3://{}/data/test/'.format(bucket)

In [4]:
# Wrap each S3 prefix as a SageMaker training-channel input (train,
# validation, test — in that order).
# NOTE(review): no content_type is given. The built-in XGBoost container is
# documented to treat untyped input as LIBSVM; if these files are CSV, pass
# content_type='text/csv' — confirm against the actual data format.
(
    s3_input_train_data,
    s3_input_validation_data,
    s3_input_test_data,
) = (
    TrainingInput(s3_data=uri)
    for uri in (train_data_s3_uri, validation_data_s3_uri, test_data_s3_uri)
)

In [5]:
# Search space explored by the tuner. Ranges spanning several orders of
# magnitude (num_round, lambda) are sampled on a log scale; the narrow
# colsample_bytree range is sampled linearly.
hyperparameter_ranges = {
    'num_round': IntegerParameter(
        min_value=1, max_value=1000, scaling_type='Logarithmic'
    ),
    'colsample_bytree': ContinuousParameter(
        min_value=0.5, max_value=1, scaling_type='Linear'
    ),
    'lambda': ContinuousParameter(
        min_value=0.0001, max_value=1000, scaling_type='Logarithmic'
    ),
}

In [6]:
# Static hyperparameters attached to the estimator.
# NOTE: 'num_round' is also a key in hyperparameter_ranges, so during tuning
# the tuner supplies its own sampled value and this 50 is only used if the
# estimator is fit directly (outside the tuner).
hyperparameters = dict(objective='binary:logistic', num_round=50)

In [7]:
# Resolve the ECR image URI of the built-in XGBoost algorithm (version 1.2-2)
# for the notebook's region.
xgboost_container = sagemaker.image_uris.retrieve(
    framework='xgboost',
    region=region,
    version='1.2-2',
)

In [8]:
# Generic estimator running the built-in XGBoost image on a single
# ml.m5.large instance. Reuses the `role` and `sess` resolved in the config
# cell instead of calling get_execution_role() a second time, so every
# SageMaker call in the notebook goes through the same session.
estimator = sagemaker.estimator.Estimator(
    image_uri=xgboost_container,
    hyperparameters=hyperparameters,
    role=role,
    instance_count=1,
    instance_type='ml.m5.large',
    sagemaker_session=sess,
)

In [9]:
# Metric the tuner optimizes: log loss on the 'validation' channel. Lower is
# better, so the tuner must minimize it.
objective_metric_name = "validation:logloss"

In [10]:
# Bayesian search over hyperparameter_ranges that minimizes validation log
# loss. The job budget (max_jobs=2, one at a time) is very small — enough to
# exercise the pipeline, not to search meaningfully; raise max_jobs for real
# tuning. early_stopping_type='Auto' lets SageMaker stop unpromising jobs.
tuner = HyperparameterTuner(
    estimator=estimator,
    hyperparameter_ranges=hyperparameter_ranges,
    objective_metric_name=objective_metric_name,
    objective_type='Minimize',
    strategy='Bayesian',
    max_jobs=2,
    max_parallel_jobs=1,
    early_stopping_type='Auto',
)

In [11]:
# Launch the tuning job. wait=False makes this call return immediately; the
# job keeps running in SageMaker. include_cls_metadata=False keeps SDK class
# metadata out of the hyperparameters sent to the built-in container.
# NOTE(review): a 'test' channel is passed alongside 'train' and 'validation';
# confirm the built-in XGBoost container accepts (and uses) it.
tuner.fit(
    wait=False,
    include_cls_metadata=False,
    inputs=dict(
        train=s3_input_train_data,
        validation=s3_input_validation_data,
        test=s3_input_test_data,
    ),
)