In [1]:
# Attach to the notebook's Spark session: getOrCreate() re-uses a running
# session if there is one, otherwise it starts a new one.
from pyspark.sql import SparkSession

spark_builder = SparkSession.builder
spark = spark_builder.getOrCreate()

In [None]:
from splicemachine.notebook import get_mlflow_ui  # used later, once the MLflow run exists
from splicemachine.notebook import get_spark_ui

# Surface the Spark UI for this session. As the cell's last expression, its
# return value (if any) is what the notebook displays.
get_spark_ui()

In [3]:
from splicemachine.spark import PySpliceContext
# Import only the Splice-patched `mlflow` object rather than `*`, so every name
# used later in the notebook has a visible origin and nothing else leaks in.
from splicemachine.mlflow_support import mlflow

# Wrap the active Spark session in a Splice Machine context.
splice = PySpliceContext(spark)
# Register the context with mlflow. This is presumably required before the
# Splice-specific calls below (e.g. log_model, which stores artifacts in the
# Splice Machine DB, and deploy_db).
mlflow.register_splice_context(splice)

In [None]:
# mlflow refuses to start a new run while another one is active. Close any run
# left over from a previous execution so this cell can be re-run safely.
if mlflow.active_run():
    mlflow.end_run()

mlflow.set_experiment('Iris_classification')
mlflow.start_run(run_name='Spark Decision Tree')
# Last expression: display the MLflow UI for the current experiment.
get_mlflow_ui(mlflow.current_exp_id())

In [14]:
from sklearn.datasets import load_iris
import pandas as pd
import numpy as np

data = load_iris()

# Normalize names such as 'sepal length (cm)' -> 'sepal_length'.
# The target column goes last, as 'label'.
feature_cols = [
    name.replace('(cm)', '').strip().replace(' ', '_')
    for name in data.feature_names
]
cols = feature_cols + ['label']

# Features followed by the target cast to float64, so every column is float64
# (in the same order as `cols`).
pdf = pd.DataFrame(data.data, columns=feature_cols).assign(
    label=data.target.astype(np.float64)
)
df = spark.createDataFrame(pdf)
df.show()

+------------+-----------+------------+-----------+-----+
|sepal_length|sepal_width|petal_length|petal_width|label|
+------------+-----------+------------+-----------+-----+
|         5.1|        3.5|         1.4|        0.2|  0.0|
|         4.9|        3.0|         1.4|        0.2|  0.0|
|         4.7|        3.2|         1.3|        0.2|  0.0|
|         4.6|        3.1|         1.5|        0.2|  0.0|
|         5.0|        3.6|         1.4|        0.2|  0.0|
|         5.4|        3.9|         1.7|        0.4|  0.0|
|         4.6|        3.4|         1.4|        0.3|  0.0|
|         5.0|        3.4|         1.5|        0.2|  0.0|
|         4.4|        2.9|         1.4|        0.2|  0.0|
|         4.9|        3.1|         1.5|        0.1|  0.0|
|         5.4|        3.7|         1.5|        0.2|  0.0|
|         4.8|        3.4|         1.6|        0.2|  0.0|
|         4.8|        3.0|         1.4|        0.1|  0.0|
|         4.3|        3.0|         1.1|        0.1|  0.0|
|         5.8|

In [21]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml import Pipeline

SEED = 42  # fixes the train/test split so re-runs train (and log) the same model

# Pack the feature columns (everything in `cols` except the trailing 'label')
# into a single vector column.
va = VectorAssembler(inputCols=cols[:-1], outputCol='features')
# Relies on DecisionTreeClassifier's default featuresCol='features' / labelCol='label'.
dt = DecisionTreeClassifier()
pipeline = Pipeline(stages=[va, dt])  # assemble features, then train the tree

# Without a seed, randomSplit draws a different split on every execution, so
# the fitted (and later deployed) model could not be reproduced.
train, test = df.randomSplit([0.8, 0.2], seed=SEED)
# NOTE(review): `test` is held out but never evaluated in this notebook.
model = pipeline.fit(train)
model

PipelineModel_e5f045504804


In [35]:
# Record how the fitted pipeline was built, then the tree's hyperparameters.
mlflow.log_pipeline_stages(model)
mlflow.log_feature_transformations(model)

tree_model = model.stages[-1]  # the fitted decision-tree stage (last in the pipeline)
for param_name in ('maxDepth', 'maxBins'):
    mlflow.lp(param_name, tree_model.getOrDefault(param_name))

# Persist the whole PipelineModel under the artifact name 'spark_dt'.
# The deployment job looks the model up by this name.
mlflow.log_model(model, 'spark_dt')

Saving artifact of size: 15.968 KB to Splice Machine DB


In [42]:
# Show deploy_db's signature and docstring before calling it below.
# help() is plain Python, unlike the IPython-only `mlflow.deploy_db?`, so this
# cell still works when the notebook is exported and run as a script.
help(mlflow.deploy_db)

In [47]:
from splicemachine.mlflow_support.utilities import get_user

# Deploy the model logged in the current run into the current user's schema.
schema = get_user()
run_id = mlflow.current_run_id()
# Feature columns only (no label). Presumably these become the model table's
# input columns, alongside the MOMENT_ID primary key.
feature_df = df.select(cols[:-1])

jid = mlflow.deploy_db(
    schema,
    'iris_model',
    run_id,
    create_model_table=True,
    df=feature_df,
    primary_key={'MOMENT_ID': 'INT'},
    classes=list(data.target_names),
)
# Last expression: stream the deployment job's logs into the cell output.
mlflow.watch_job(jid)

Deploying model to database...
Your Job has been submitted. The returned value of this function is the job id, which you can use to monitor the your task in real-time. Run mlflow.watch_job(<job id>) tostream them to stdout, or mlflow.fetch_logs(<job id>) to read them one time to a list
---Job Logs------Job Logs---
INFO     2020-09-11 14:17:56.542 - A service worker has found your request
INFO     2020-09-11 14:17:56.665 - Checking whether handler DEPLOY_DATABASE is enabled
INFO     2020-09-11 14:17:56.717 - Handler is available
INFO     2020-09-11 14:17:56.745 - Retrieving Run from MLFlow Tracking Server...
INFO     2020-09-11 14:17:56.850 - Retrieved MLFlow Run
INFO     2020-09-11 14:17:56.880 - Updating MLFlow Run for the UI
INFO     2020-09-11 14:17:56.969 - Reading Model Artifact Stream from Splice Machine
INFO     2020-09-11 14:17:56.996 - Extracting Model from DB with Name: spark_dt
INFO     2020-09-11 14:17:57.064 - Decoding Model Artifact Binary Stream for Deployment
INFO     2

In [73]:
%%sql

-- Remove the two demo rows first so that re-running this cell does not fail on a duplicate MOMENT_ID primary key
delete from iris_model where moment_id in (0, 1);
insert into iris_model (sepal_length, sepal_width, petal_length, petal_width, moment_id) values (5.1, 3.5, 1.4, 0.2, 0);
insert into iris_model (sepal_length, sepal_width, petal_length, petal_width, moment_id) values (6.4, 2.7, 5.3, 2.0, 1);

select * from iris_model;

  and should_run_async(code)


In [74]:
# Close the MLflow run started earlier in the notebook so it is recorded as
# finished rather than left open, then release the Spark session.
mlflow.end_run()
spark.stop()