# Training and testing a model with Spark and MemSQL - Pushdown Enabled

Set up the Spark session

In [1]:
import os
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages "com.memsql:memsql-spark-connector_2.11:3.0.0-rc1-spark-2.4.4" pyspark-shell'

In [2]:
import pyspark
spark = pyspark.sql.SparkSession.builder.master("local[*]").getOrCreate()

Connect to MemSQL

In [3]:
spark.conf.set("spark.datasource.memsql.ddlEndpoint", "localhost")
spark.conf.set("spark.datasource.memsql.user", "root")
spark.conf.set("spark.datasource.memsql.password", "")

spark.conf.set("spark.datasource.memsql.disablePushdown", "false")
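The settings above can also be kept in one dictionary and applied in a loop, which makes it easier to flip `disablePushdown` when comparing runs. The following is a minimal sketch, not part of the connector: `memsql_settings` and `apply_settings` are hypothetical helper names, and `conf` is assumed to be any object with a `set(key, value)` method, such as `spark.conf`.

```python
def memsql_settings(disable_pushdown=False, endpoint="localhost",
                    user="root", password=""):
    """Return the connector settings used in this notebook as a dict.

    Hypothetical helper: the keys are the ones set in the cell above.
    """
    prefix = "spark.datasource.memsql."
    return {
        prefix + "ddlEndpoint": endpoint,
        prefix + "user": user,
        prefix + "password": password,
        # The notebook passes this flag as the string "true" or "false".
        prefix + "disablePushdown": str(disable_pushdown).lower(),
    }


def apply_settings(conf, settings):
    """Apply every key/value pair through conf.set(key, value)."""
    for key, value in settings.items():
        conf.set(key, value)


settings = memsql_settings(disable_pushdown=False)
for key, value in settings.items():
    print(key, "=", repr(value))
```

In the notebook this would be called as `apply_settings(spark.conf, memsql_settings())`.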

Query data from a MemSQL table

In [4]:
data = spark.read.format("memsql") \
    .load("tpch.lineitem_bu") \
    .select('l_partkey','l_suppkey','l_quantity','l_discount','l_tax','l_extendedprice') \
    .limit(1000000)

Assemble the feature vector (the columns used as predictors in the model)

In [5]:
from pyspark.ml.feature import VectorAssembler

# Predictor columns that are combined into a single "features" vector column
feature_columns = ['l_partkey', 'l_suppkey', 'l_quantity', 'l_discount', 'l_tax']
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
data_2 = assembler.transform(data)
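Conceptually, the assembler concatenates the listed columns of each row, in order, into one numeric vector. The plain-Python sketch below illustrates the idea only; it is not how Spark implements it. `assemble` is a hypothetical helper, and the sample row reuses values that appear in the predictions output of this notebook.

```python
feature_columns = ['l_partkey', 'l_suppkey', 'l_quantity', 'l_discount', 'l_tax']


def assemble(row, input_cols):
    """Concatenate the chosen columns of one row into a list of floats."""
    return [float(row[col]) for col in input_cols]


# One row, using values shown in the predictions table of this notebook
row = {'l_partkey': 74, 'l_suppkey': 75, 'l_quantity': 1.00,
       'l_discount': 0.08, 'l_tax': 0.05}

features = assemble(row, feature_columns)
print(features)  # [74.0, 75.0, 1.0, 0.08, 0.05]
```

The order of `inputCols` matters: the position of each value in the vector determines which regression coefficient it is paired with.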

Split the data into training and test sets, and configure the linear regression.
The cells that follow are timed with `%%time` for performance benchmarking.

In [6]:
train, test = data_2.randomSplit([0.7, 0.3])
from pyspark.ml.regression import LinearRegression
algo = LinearRegression(featuresCol="features", labelCol="l_extendedprice")
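The split is random, so the two sets hold roughly, not exactly, 70% and 30% of the rows. As a conceptual sketch only (plain Python, not Spark's implementation), each row can be thought of as drawing a uniform random number and landing in the bucket whose cumulative weight range contains it. `random_split` is a hypothetical helper name.

```python
import random


def random_split(rows, weights, seed=None):
    """Assign each row to one bucket with probability proportional to its weight."""
    rng = random.Random(seed)
    total = sum(weights)

    # Cumulative upper bounds of the normalized weights, e.g. roughly [0.7, 1.0]
    bounds = []
    cumulative = 0.0
    for weight in weights:
        cumulative += weight / total
        bounds.append(cumulative)

    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        for i, bound in enumerate(bounds):
            # The last bucket also catches any floating-point rounding gap
            if u < bound or i == len(bounds) - 1:
                splits[i].append(row)
                break
    return splits


train_rows, test_rows = random_split(range(10000), [0.7, 0.3], seed=42)
print(len(train_rows), len(test_rows))
```

PySpark's `randomSplit` also accepts a `seed` argument, for example `data_2.randomSplit([0.7, 0.3], seed=42)`, which may help when comparing runs.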

Train the model. Note: open the Resource Usage page in MemSQL Studio to see the queries that run against MemSQL during training.

In [7]:
%%time
model = algo.fit(train)

CPU times: user 13.7 ms, sys: 6.57 ms, total: 20.2 ms
Wall time: 29.9 s
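The CPU time is tiny compared with the wall time because `%%time` reports the CPU time of the Python process only, while the actual work runs outside it, in Spark's JVM and in MemSQL. The sketch below reproduces that gap in plain Python, with `time.sleep` standing in for waiting on an external system; `timed` is a hypothetical helper.

```python
import time


def timed(fn):
    """Run fn() and return (result, cpu_seconds, wall_seconds)."""
    wall_start = time.perf_counter()   # elapsed real time
    cpu_start = time.process_time()    # CPU time used by this process only
    result = fn()
    cpu_seconds = time.process_time() - cpu_start
    wall_seconds = time.perf_counter() - wall_start
    return result, cpu_seconds, wall_seconds


# Sleeping uses almost no CPU in this process, much like waiting on Spark
_, cpu_seconds, wall_seconds = timed(lambda: time.sleep(0.2))
print(f"CPU: {cpu_seconds:.3f} s, wall: {wall_seconds:.3f} s")
```

For benchmarking pushdown, the wall time is therefore the figure to compare between runs.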


Collect model metrics and make predictions.

In [8]:
%%time
# Evaluate the fitted model on the held-out test set
evaluation_summary = model.evaluate(test)

CPU times: user 3.33 ms, sys: 1.93 ms, total: 5.26 ms
Wall time: 13.4 s


In [9]:
%%time
# Add a 'prediction' column; transform() is lazy, so nothing is computed until an action such as show()
predictions = model.transform(test)

CPU times: user 7.42 ms, sys: 558 µs, total: 7.98 ms
Wall time: 57.2 ms


In [10]:
r_squared = evaluation_summary.r2
r_squared

0.8629420794044855
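The `r2` value is the coefficient of determination, R² = 1 − SS_res / SS_tot, where SS_res is the sum of squared differences between labels and predictions and SS_tot is the sum of squared differences between labels and their mean. A minimal plain-Python sketch of that formula follows; `r2_sketch` is a hypothetical helper, and the numbers are toy values for illustration, not data from this notebook.

```python
def r2_sketch(labels, predictions):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_label = sum(labels) / len(labels)
    ss_res = sum((y - p) ** 2 for y, p in zip(labels, predictions))
    ss_tot = sum((y - mean_label) ** 2 for y in labels)
    return 1.0 - ss_res / ss_tot


# Toy values chosen only to exercise the formula
toy_labels = [1.0, 2.0, 3.0, 4.0]
toy_predictions = [1.1, 1.9, 3.2, 3.8]
print(r2_sketch(toy_labels, toy_predictions))
```

A value of 1.0 means the predictions match the labels exactly; the value above means the model explains about 86% of the variance of `l_extendedprice` in the test set.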

Show predictions.

In [11]:
predictions \
    .select('l_partkey', 'l_suppkey', 'l_quantity', 'l_discount', 'l_tax', 'prediction') \
    .show()

+---------+---------+----------+----------+-----+------------------+
|l_partkey|l_suppkey|l_quantity|l_discount|l_tax|        prediction|
+---------+---------+----------+----------+-----+------------------+
|       38|   750039|     46.00|      0.02| 0.01| 68977.74461490566|
|       74|       75|      1.00|      0.08| 0.05|1447.3229085589162|
|      158|   250159|     32.00|      0.00| 0.00| 47934.72253363739|
|      251|      252|     11.00|      0.09| 0.05| 16448.97259354734|
|      297|   500298|      6.00|      0.07| 0.06| 8988.435249517062|
|      318|      319|     29.00|      0.08| 0.05| 43451.31132863032|
|      518|      519|      6.00|      0.02| 0.03| 8933.819999290781|
|      521|   500522|     22.00|      0.08| 0.06| 32990.93922878267|
|      523|      524|     48.00|      0.09| 0.06| 71960.70395181864|
|      541|   500542|     38.00|      0.06| 0.07|56999.236464156704|
|      903|   500904|      5.00|      0.07| 0.05| 7481.871057433555|
|     1005|   751006|      9.00|  