In [16]:
import boto3
import sagemaker
from sagemaker.inputs import TrainingInput
from sagemaker.tuner import IntegerParameter
from sagemaker.tuner import ContinuousParameter
from sagemaker.tuner import CategoricalParameter
from sagemaker.tuner import HyperparameterTuner
from sagemaker.xgboost.estimator import XGBoost

In [3]:
# SageMaker session, plus the default bucket and execution role used by the
# rest of the notebook.
sess = sagemaker.Session()
bucket, role = sess.default_bucket(), sagemaker.get_execution_role()

# Region of the default boto3 session, reused for the low-level SageMaker client.
region = boto3.Session().region_name
sm = boto3.Session().client('sagemaker', region_name=region)

In [7]:
# S3 locations of the train / validation / test data inside `bucket`.
# Each value is a full S3 URI (scheme included) because it is handed to
# TrainingInput(s3_data=...), and a bare "<bucket>/<prefix>" string is not a
# valid S3 location. Every prefix ends with "/" so that it cannot also match
# sibling prefixes that merely share the same leading characters.
train_data_s3_uri = 's3://{}/train/train/'.format(bucket)
validation_data_s3_uri = 's3://{}/train/validation/'.format(bucket)
test_data_s3_uri = 's3://{}/test/'.format(bucket)

In [10]:
# Wrap each S3 location in a TrainingInput, one per dataset split
# (train, validation, test — in that order).
s3_input_train_data, s3_input_validation_data, s3_input_test_data = (
    TrainingInput(s3_data=s3_uri)
    for s3_uri in (train_data_s3_uri, validation_data_s3_uri, test_data_s3_uri)
)

In [14]:
# Search space for the hyperparameter tuning job, keyed by hyperparameter name.
hyperparameter_ranges = {
    # Number of boosting rounds, sampled uniformly from 50..100.
    'num_round': IntegerParameter(50, 100, scaling_type='Linear'),
    # Logarithmic scaling is only valid for ranges strictly greater than 0, so
    # it cannot be combined with a lower bound of 0. 'Auto' keeps the original
    # [0, 1000] bounds and lets the tuner choose a scale that is valid for them.
    'alpha': ContinuousParameter(0, 1000, scaling_type='Auto'),
    # Booster type is a categorical choice among these three values.
    'booster': CategoricalParameter(['gbtree', 'gblinear', 'dart'])
}

In [15]:
# Regexes used to scrape metric values out of the training job's log output.
# The single capture group in each pattern is the numeric value recorded under
# the metric's 'Name'.
#
# The train patterns start with \b so they only match a standalone
# 'loss: ' / 'accuracy: ' field. Without it they also match the tail of
# 'val_loss: ' / 'val_accuracy: ', and validation values would be recorded as
# training metrics.
#
# NOTE(review): these patterns look like Keras-style log fields, while this
# notebook imports the XGBoost estimator — confirm they match the lines the
# training container actually prints.
metrics_definitions = [
    {'Name': 'train:accuracy', 'Regex': '\\baccuracy: ([0-9\\.]+)'},
    {'Name': 'train:loss', 'Regex': '\\bloss: ([0-9\\.]+)'},
    {'Name': 'validation:loss', 'Regex': 'val_loss: ([0-9\\.]+)'},
    {'Name': 'validation:accuracy', 'Regex': 'val_accuracy: ([0-9\\.]+)'}
]