In [None]:
# Reload imported modules automatically before executing each cell,
# so edits to project modules take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Substep parameters used for an interactive run.
# During a job run this cell is replaced with the parameters from the JSON file in the params subfolder.
#   rmse_threshold - maximum acceptable RMSE of the model on the eval set; the "check eval result"
#                    cell raises when the measured RMSE is greater than this value.
substep_params = {"rmse_threshold": 15}

In [None]:
# load pipeline and step parameters - do not edit
# Reads the pipeline-level and step-level parameters through sinara; they are passed to
# NotebookSubstep when the substep interface is defined.
# NOTE(review): pprint=True presumably pretty-prints the parameters to the cell output — confirm in sinara.substep.
from sinara.substep import get_pipeline_params, get_step_params
pipeline_params = get_pipeline_params(pprint=True)
step_params = get_step_params(pprint=True)

In [None]:
# define substep interface
# This substep consumes the "bento_service" entity produced by the "model_train" step
# and declares no outputs of its own.
from sinara.substep import (
    NotebookSubstep,
    NotebookSubstepRunResult,
    ENV_NAME,
    PIPELINE_NAME,
    ZONE_NAME,
    STEP_NAME,
    RUN_ID,
    ENTITY_NAME,
    ENTITY_PATH,
    SUBSTEP_NAME,
)

substep = NotebookSubstep(pipeline_params, step_params, substep_params)

# Input entity descriptor: which step produced it and under which entity name
bento_service_input = {STEP_NAME: "model_train", ENTITY_NAME: "bento_service"}

substep.interface(inputs=[bento_service_input], outputs=[])

substep.print_interface_info()

# NOTE(review): presumably stops notebook execution here when run in visualize mode — confirm in sinara.substep
substep.exit_in_visualize_mode()

In [None]:
# run spark
from sinara.spark import SinaraSpark

# NOTE(review): the meaning of the 0 argument to run_session is not visible here — confirm in sinara.spark
spark = SinaraSpark.run_session(0)

# Show the Spark UI URL as this cell's output
spark_ui_url = SinaraSpark.ui_url()
spark_ui_url

In [None]:
# read inputs
# Fetch the inputs of the "model_train" step declared in substep.interface(...);
# its bento_service entity is loaded in the model API test below.
bento_step_inputs = substep.inputs(step_name="model_train")

In [None]:
# Test model API: serve the trained BentoService locally, fetch its bundled test data
# through the HTTP endpoints, score it via /predict and compute the RMSE.
from sinara.bentoml import start_dev_bentoservice, stop_dev_bentoservice, load_bentoservice
from sklearn.metrics import root_mean_squared_error
import pandas as pd
import requests
import json

# Base URL of the local dev model server that start_dev_bentoservice brings up
# NOTE(review): host/port are assumed to match the dev server's defaults — confirm in sinara.bentoml
BENTO_DEV_URL = "http://127.0.0.1:5000"


def _call_endpoint(endpoint, payload):
    """POST ``payload`` as JSON to ``endpoint`` on the dev model server.

    Args:
        endpoint: endpoint name relative to BENTO_DEV_URL, e.g. "predict".
        payload: JSON-serialisable request body.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status, so a failing
            endpoint is reported as such instead of as a JSON decoding error or a bogus metric.
    """
    response = requests.post(f"{BENTO_DEV_URL}/{endpoint}", json=payload)
    response.raise_for_status()
    return json.loads(response.text)


# Load BentoService
bento_serv = load_bentoservice(bento_step_inputs.bento_service)

# Stop a dev model server if one is still running (e.g. from an interrupted earlier run)
stop_dev_bentoservice(bento_serv)

# Start a dev model server to test out the API endpoint locally
start_dev_bentoservice(bento_serv)
try:
    serv_v = _call_endpoint("service_version", {})
    print(serv_v)

    test_data = _call_endpoint("test_data", {})

    preds = _call_endpoint("predict", test_data['X'])

    rmse = root_mean_squared_error(pd.DataFrame(test_data['Y']).values, preds)
    print("The root mean squared error (RMSE) on eval set: {:.4f}".format(rmse))
finally:
    # Always stop the dev model server, even when a request or the metric computation fails,
    # so the server is not left running after this cell
    stop_dev_bentoservice(bento_serv)

In [None]:
# check eval result
# Fail the substep when the RMSE measured in the model API test is greater than the
# acceptable threshold configured in substep_params.
rmse_threshold = substep_params["rmse_threshold"]
rmse_exceeds_threshold = rmse > rmse_threshold
if rmse_exceeds_threshold:
    raise Exception(f'RMSE is {rmse}, more than acceptable value of {rmse_threshold}')

In [None]:
# stop spark
# Shut down the Spark session opened by SinaraSpark.run_session(...) earlier in the notebook.
# NOTE(review): this cell is not reached when the eval-result check raises, so the session is
# not stopped explicitly in that case — confirm whether the job runner cleans it up.
SinaraSpark.stop_session()