# R Forecasting Bring-Your-Own Container Demo

This notebook demonstrates how to use the container to create an endpoint serving
forecasts made by the [R forecast package](https://cran.r-project.org/web/packages/forecast/index.html) using Amazon SageMaker.

In [None]:
import json
import time
import urllib
import urllib.request

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sagemaker

%matplotlib inline

## Setup


In [None]:
# Set this to the ECR image URL of the container, as returned by build_and_push.sh,
# e.g. 123456789.dkr.ecr.us-east-1.amazonaws.com/r_forecast_bring_your_own:latest
CONTAINER_IMAGE = "to_be_set"

# Instance type and instance count used for the endpoint's production variant.
INSTANCE_COUNT = 1
INSTANCE_TYPE = "ml.c5.xlarge"

# Execution role looked up through the SageMaker SDK (or set it manually).
EXECUTION_ROLE_ARN = sagemaker.get_execution_role()

In [None]:
# One SageMaker session provides both clients used in this notebook:
# `sagemaker_client` for creating/describing resources and
# `sagemaker_runtime_client` for invoking the endpoint.
sagemaker_session = sagemaker.Session()
sagemaker_runtime_client = sagemaker_session.sagemaker_runtime_client
sagemaker_client = sagemaker_session.sagemaker_client

# Region for SageMaker calls -- should be the same as your ECR
print("Region:", sagemaker_session.boto_region_name)

## Endpoint Creation

In [None]:
# Create endpoint from container image

# One timestamped name is shared by the model, the endpoint config and the endpoint.
name = 'r-forecast-test-{}'.format(time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime()))

# Create the Model
# No ModelDataUrl is provided in the primary container, as there is no training step.
primary_container = {'Image': CONTAINER_IMAGE}
create_model_response = sagemaker_client.create_model(
    ModelName=name,
    ExecutionRoleArn=EXECUTION_ROLE_ARN,
    PrimaryContainer=primary_container,
)
print("ModelArn:", create_model_response['ModelArn'])

time.sleep(5)  # wait for model creation to finish

# Create the EndpointConfig with a single production variant named 'AllTraffic'
production_variant = {
    'VariantName': 'AllTraffic',
    'ModelName': name,
    'InstanceType': INSTANCE_TYPE,
    'InitialInstanceCount': INSTANCE_COUNT,
}
create_endpoint_config_response = sagemaker_client.create_endpoint_config(
    EndpointConfigName=name,
    ProductionVariants=[production_variant],
)
print("EndpointConfigArn:", create_endpoint_config_response['EndpointConfigArn'])

time.sleep(5)

# Create the Endpoint from the config created above
create_endpoint_response = sagemaker_client.create_endpoint(
    EndpointConfigName=name,
    EndpointName=name,
)
print("EndpointArn:", create_endpoint_response['EndpointArn'])

In [None]:
# Query endpoint status
# The status needs to change to 'InService' before continuing

describe_endpoint_response = sagemaker_client.describe_endpoint(EndpointName=name)

print("Creating endpoint", end="")
# Poll every 10 seconds for as long as the endpoint still reports 'Creating'.
while describe_endpoint_response['EndpointStatus'] == 'Creating':
    time.sleep(10)
    describe_endpoint_response = sagemaker_client.describe_endpoint(EndpointName=name)
    print(".", end="", flush=True)

# Only report "done" when the endpoint really is in service; any other final
# status is printed as-is so a failed creation is not reported as success.
endpoint_status = describe_endpoint_response['EndpointStatus']
print(" done" if endpoint_status == 'InService' else " " + endpoint_status)

# Still an AssertionError on failure, but now carrying the final status and
# the failure reason from the describe_endpoint response (when present).
assert endpoint_status == 'InService', (
    "Endpoint '{}' ended in status '{}': {}".format(
        name,
        endpoint_status,
        describe_endpoint_response.get('FailureReason', 'no failure reason reported'),
    )
)

## Basic Forecast Request / Response

In [None]:
# Construct a basic request. The request format is the same as the one
# used by the DeepAR algorithm, see
# https://docs.aws.amazon.com/sagemaker/latest/dg/deepar-in-formats.html

toy_time_series = list(range(1, 6))

# A single instance carrying a start date and the toy series as its target values.
toy_instance = dict(start="2018-01-01", target=toy_time_series)
request = dict(instances=[toy_instance])

In [None]:
# Send the JSON-encoded request to the endpoint ...
response = sagemaker_runtime_client.invoke_endpoint(
    Body=json.dumps(request),
    ContentType='application/json',
    EndpointName=name,
)

# ... then decode the response body and keep its "predictions" entry.
response_text = response['Body'].read().decode()
forecasts = json.loads(response_text)["predictions"]
forecasts

In [None]:
# Plot the toy series ('x' markers) followed by its mean forecast (black 'x' markers)

n_observed = len(toy_time_series)
mean_forecast = forecasts[0]["mean"]

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(14, 4))
ax.plot(range(n_observed), toy_time_series, 'x')
# the forecast is placed on the x-axis right after the last observation
ax.plot(range(n_observed, n_observed + len(mean_forecast)), mean_forecast, 'kx')
ax.grid()

## Monthly Milk Production Data Example

Now we'll try a more interesting data set, namely the "Monthly milk production: pounds per cow. Jan 62 – Dec 75" data set from
[Rob Hyndman's Time Series Data Library](https://robjhyndman.com/hyndsight/tsdl/)
available from [here](https://datamarket.com/data/set/22ox).

In [None]:
# retrieve the data from the DataMarket API
URL = "https://datamarket.com/api/v1/series.json?ds=22ox"

# Read the body inside a `with` block so the HTTP response is closed once
# the payload has been read.
with urllib.request.urlopen(URL) as http_response:
    response = http_response.read()

# NOTE(review): the slice presumably strips a JSONP-style wrapper (an 18-byte
# prefix and a 1-byte suffix) around the JSON payload -- confirm against the API.
data = json.loads(response[18:-1])[0]['data']

# The second field of every entry becomes a value of the series; the index is
# rebuilt as a monthly date range starting at the first entry's first field.
milk_production = pd.Series(
    [x[1] for x in data],
    index=pd.date_range(data[0][0], periods=len(data), freq="1M"),
)

In [None]:
# reserve the last two years (the final 24 observations) for testing
milk_production_train = milk_production.iloc[:-24]
milk_production_test = milk_production.iloc[-24:]

# draw the training and the test part on the same axes
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(14, 4))
for series_part in (milk_production_train, milk_production_test):
    series_part.plot(ax=ax)
ax.grid()

In [None]:
# For convenience, we define a function for retrieving forecasts from the endpoint
def get_forecast(start, target, method, frequency, prediction_length,
                 quantiles=("0.1", "0.5", "0.9"), endpoint_name=None):
    """Request forecasts for a single time series from the endpoint.

    Parameters
    ----------
    start : str
        Sent as the "start" field of the single request instance.
    target : list
        Sent as the "target" field of the single request instance; must be
        JSON-serializable.
    method : str
        Sent as "method" in the request configuration.
    frequency
        Sent as "frequency" in the request configuration.
    prediction_length
        Sent as "prediction_length" in the request configuration.
    quantiles : sequence of str, optional
        Sent as "quantiles" in the request configuration. Defaults to
        ("0.1", "0.5", "0.9").
    endpoint_name : str, optional
        Name of the endpoint to invoke. Defaults to the notebook-level
        ``name`` variable.

    Returns
    -------
    The "predictions" entry of the JSON-decoded response body.
    """
    if endpoint_name is None:
        # fall back to the endpoint name defined at notebook level
        endpoint_name = name

    request = {
        "configuration": {
            "frequency": frequency,
            "method": method,
            "output_types": ["mean", "quantiles"],
            "prediction_length": prediction_length,
            # list() so that any sequence (e.g. the tuple default) serializes as a JSON array
            "quantiles": list(quantiles)
        },
        "instances": [
            {
                "start": start,
                "target": target
            }
        ]
    }

    # obtain a response from the endpoint
    response = sagemaker_runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=json.dumps(request)
    )

    return json.loads(response['Body'].read().decode())["predictions"]

In [None]:
# Plot the forecasts: one panel per forecasting method, all sharing the y-axis
fig, axs = plt.subplots(4, 1, sharey=True, figsize=(14, 16))
axs = axs.ravel()

# The request inputs are the same for every method, so build them once.
train_start = str(milk_production_train.index[0])
train_values = milk_production_train.values.tolist()
test_index = milk_production_test.index

for panel, method in zip(axs, ('ets', 'ets_additive', 'arima', 'tbats')):
    forecasts = get_forecast(
        start=train_start,
        target=train_values,
        method=method, frequency=12, prediction_length=24
    )
    prediction = forecasts[0]

    # observed data: training part and held-out part
    milk_production_train.plot(ax=panel)
    milk_production_test.plot(ax=panel)

    # mean forecast over the test period, plus the band between the
    # 0.1 and 0.9 quantiles
    pd.Series(prediction['mean'], test_index).plot(ax=panel)
    panel.fill_between(
        test_index,
        prediction['quantiles']['0.1'],
        prediction['quantiles']['0.9']
    )
    panel.set_title(method)
    panel.grid()

If you are planning to make a large number of forecast requests, you can increase the number of instances (and the number of cores per instance) to achieve the desired throughput. Note, however, that the parallelism is on a per-request basis, i.e. instead of sending a single request containing a large number of instances, requests containing a few instances each should be made in parallel (e.g. using Python's `multiprocessing` module).