## Trainer

What you need: 
Specify the S3 locations of the training data channel and, optionally, the test data channel.

What it does: 
Trains a DeepAR model, deploys it to an Amazon SageMaker endpoint for inference, and deletes the endpoint once the model is no longer needed.

In [None]:
import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.amazon.amazon_estimator import get_image_uri

In [None]:
# Training configuration: S3 data/output locations and DeepAR forecast settings.
# Every S3 location lives under one bucket/prefix, so the whole job can be
# relocated by editing these two values.
s3_bucket = 'sagemaker-deepar20190120'
s3_prefix = 'sagemaker/wiki-test-deepar'
s3_base_path = 's3://{}/{}'.format(s3_bucket, s3_prefix)

s3_train_path = s3_base_path + '/data/train_subset'
s3_test_path = s3_base_path + '/data/test_subset'
s3_model_output_path = s3_base_path + '/output'

# Series granularity passed to DeepAR as `time_freq` ('H' = hourly).
freq = 'H'
# Number of future time steps to forecast (DeepAR `prediction_length`).
prediction_length = 48
# Number of past time steps the model conditions on (DeepAR `context_length`).
context_length = 118

In [None]:
# Create the SageMaker session, resolve the notebook's execution role, and
# configure a single-instance DeepAR training job whose artifacts are written
# to s3_model_output_path.
sagemaker_session = sagemaker.Session()
role = get_execution_role()
# Look up the DeepAR container in the same region as sagemaker_session (which
# launches the job) instead of a separately constructed boto3 session, so the
# image and the training job are always resolved against one region.
image_name = get_image_uri(sagemaker_session.boto_region_name, 'forecasting-deepar')
# NOTE(review): image_name / train_instance_count / train_instance_type are
# SageMaker Python SDK v1 argument names (renamed in v2) — confirm the SDK
# version pinned for this notebook.
estimator = sagemaker.estimator.Estimator(
    sagemaker_session=sagemaker_session,
    image_name=image_name,
    role=role,
    train_instance_count=1,
    train_instance_type='ml.m4.xlarge',
    base_job_name='DEMO-deepar',
    output_path=s3_model_output_path,
)

# DeepAR hyperparameters. Every value is supplied as a string; the two
# window lengths configured above are converted with str().
hyperparameters = dict(
    time_freq=freq,
    context_length=str(context_length),
    prediction_length=str(prediction_length),
    num_cells="35",
    num_layers="2",
    likelihood="student-t",
    epochs="39",
    mini_batch_size="85",
    learning_rate="0.0030902721170490166",
    dropout_rate="0.052384954005170334",
    early_stopping_patience="10",
)

# Attach the settings to the estimator as keyword arguments.
estimator.set_hyperparameters(**hyperparameters)

# Map SageMaker input channel names to their S3 locations. "train" is always
# provided; the "test" channel is optional and is only added when
# s3_test_path is set (left empty/None, training runs without it).
data_channels = {"train": s3_train_path}
if s3_test_path:
    data_channels["test"] = s3_test_path

In [None]:
# Launch the training job on the configured input channels. wait=True (the
# SDK default) keeps this cell blocked until the job has finished.
estimator.fit(inputs=data_channels, wait=True)

In [None]:
# Host the model produced by the most recent training job on a SageMaker
# endpoint served by the same DeepAR image used for training. The returned
# endpoint name is kept for the cleanup step.
job_name = estimator.latest_training_job.name

endpoint_name = sagemaker_session.endpoint_from_job(
    job_name=job_name,
    deployment_image=image_name,
    role=role,
    instance_type='ml.m4.xlarge',
    initial_instance_count=1,
)

In [None]:
# Tear down the endpoint once the model is no longer in use.
# NOTE(review): only the endpoint itself is deleted here; the model and endpoint
# configuration created by endpoint_from_job presumably remain — confirm and
# remove them too if required.
sagemaker_session.delete_endpoint(endpoint_name=endpoint_name)
