# Training and testing a model with Spark and MemSQL - Pushdown Disabled

Set up the Spark session

In [1]:
import os
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages "com.memsql:memsql-spark-connector_2.11:3.0.0-rc1-spark-2.4.4" pyspark-shell'

In [2]:
import pyspark
spark = pyspark.sql.SparkSession.builder.master("local[*]").getOrCreate()

Connect to MemSQL

In [3]:
spark.conf.set("spark.datasource.memsql.ddlEndpoint", "localhost")
spark.conf.set("spark.datasource.memsql.user", "root")
spark.conf.set("spark.datasource.memsql.password", "")
spark.conf.set("spark.datasource.memsql.disablePushdown", "true")
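
Since this notebook's point is to benchmark with pushdown disabled, it can help to keep the connector settings in one place so the flag is easy to flip for a comparison run. The following is a minimal sketch, not part of the connector: `memsql_conf` and `apply_conf` are hypothetical helper names, and the only keys used are the four set in the cell above.

```python
def memsql_conf(disable_pushdown, endpoint="localhost", user="root", password=""):
    """Return the connector settings from the cell above as a plain dict.

    Hypothetical helper for illustration; the keys are the ones used in this notebook.
    """
    return {
        "spark.datasource.memsql.ddlEndpoint": endpoint,
        "spark.datasource.memsql.user": user,
        "spark.datasource.memsql.password": password,
        # The notebook passes the flag as a string, so render it as "true"/"false".
        "spark.datasource.memsql.disablePushdown": str(disable_pushdown).lower(),
    }


def apply_conf(session, settings):
    """Apply each setting with session.conf.set, mirroring the cell above."""
    for key, value in settings.items():
        session.conf.set(key, value)


settings = memsql_conf(disable_pushdown=True)
print(settings["spark.datasource.memsql.disablePushdown"])
```

With the `spark` session from cell 2, `apply_conf(spark, memsql_conf(disable_pushdown=True))` would reproduce the configuration in cell 3.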

Query data from MemSQL Table

In [4]:
data = spark.read.format("memsql") \
    .load("tpch.lineitem") \
    .select('l_partkey','l_suppkey','l_quantity','l_discount','l_tax','l_extendedprice') \
    .limit(1000000)

Assemble the feature vector (the columns used as predictors in the model)

In [5]:
from pyspark.ml.feature import VectorAssembler

feature_columns = ['l_partkey', 'l_suppkey', 'l_quantity', 'l_discount', 'l_tax']
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
data_2 = assembler.transform(data)
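
To make the assembly step concrete, here is a plain-Python sketch of what happens conceptually to a single row: the predictor columns are concatenated, in order, into one numeric vector, and the label column is left out. This is an illustration only, not how `VectorAssembler` is implemented. The `assemble` function is a hypothetical name, and the `l_extendedprice` value in the sample row is a placeholder.

```python
feature_columns = ['l_partkey', 'l_suppkey', 'l_quantity', 'l_discount', 'l_tax']


def assemble(row, input_cols):
    """Concatenate the named columns of one row into a list of floats."""
    return [float(row[name]) for name in input_cols]


# Sample row shaped like a tpch.lineitem record. The l_extendedprice value is a
# placeholder for illustration; it is the label, so it stays out of the vector.
row = {
    'l_partkey': 74,
    'l_suppkey': 75,
    'l_quantity': '1.00',
    'l_discount': '0.08',
    'l_tax': '0.05',
    'l_extendedprice': '1000.00',
}

features = assemble(row, feature_columns)
print(features)
```

The order of `feature_columns` fixes the order of the entries in the vector, which is also the order of the coefficients the fitted model reports.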

Split the data into training (70%) and test (30%) sets, and configure a linear regression model with l_extendedprice as the label.
The %%time cells that follow record wall time for performance benchmarking.

In [6]:
train, test = data_2.randomSplit([0.7, 0.3])
from pyspark.ml.regression import LinearRegression
algo = LinearRegression(featuresCol="features", labelCol="l_extendedprice")
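
The split above is random per row, so the 70/30 proportions are approximate and the result changes between runs unless a seed is supplied (`randomSplit` accepts an optional `seed` argument). The sketch below illustrates the idea in plain Python: normalize the weights, draw one uniform number per row, and assign the row to the split whose range contains it. `random_split` is a hypothetical stand-in and does not reproduce Spark's exact sampling.

```python
import random


def random_split(rows, weights, seed=None):
    """Assign each row to one split with probability proportional to its weight."""
    rng = random.Random(seed)
    total = sum(weights)

    # Cumulative upper bounds of each split after normalizing the weights.
    bounds = []
    running = 0.0
    for weight in weights:
        running += weight / total
        bounds.append(running)

    splits = [[] for _ in weights]
    last = len(bounds) - 1
    for row in rows:
        u = rng.random()
        for i, bound in enumerate(bounds):
            # The last split also catches any floating-point shortfall in bounds.
            if u < bound or i == last:
                splits[i].append(row)
                break
    return splits


train_rows, test_rows = random_split(range(1000), [0.7, 0.3], seed=42)
print(len(train_rows), len(test_rows))
```

Fixing the seed makes the split repeatable, which matters when comparing timings or metrics across runs.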

Train the model.
Note: While the cell runs, check MemSQL Studio's Resource Usage to see the queries being executed.

In [7]:
%%time
model = algo.fit(train)

CPU times: user 25.9 ms, sys: 9.27 ms, total: 35.1 ms
Wall time: 1min 54s
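
For intuition about what `fit` estimates, the sketch below solves the one-feature case of least squares in closed form. It is a simplification under stated assumptions: the model above uses five features and Spark chooses its own solver, so this is not the algorithm Spark runs. `fit_line` and the toy data are hypothetical.

```python
def fit_line(xs, ys):
    """Ordinary least squares for y = slope * x + intercept (single feature)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Covariance-like and variance-like sums (the 1/n factors cancel in the ratio).
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    sxx = sum((x - mean_x) ** 2 for x in xs)
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept


# Toy data lying exactly on y = 2x + 1.
quantities = [1.0, 2.0, 3.0, 4.0]
prices = [3.0, 5.0, 7.0, 9.0]

slope, intercept = fit_line(quantities, prices)
print(slope, intercept)
```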


Collect model metrics and make predictions.

In [8]:
%%time
# evaluation
evaluation_summary = model.evaluate(test)

CPU times: user 13.2 ms, sys: 7.31 ms, total: 20.5 ms
Wall time: 1min 40s


In [9]:
%%time
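# predicting values (transform is lazy: this only builds the plan, rows are computed when an action such as show() runs)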
predictions = model.transform(test)

CPU times: user 5.59 ms, sys: 2.81 ms, total: 8.4 ms
Wall time: 52 ms


In [10]:
r_squared = evaluation_summary.r2
r_squared

0.8626332754492319
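
The value above is the coefficient of determination on the test split: roughly 86% of the variance in l_extendedprice is explained by the model. As a reference for how the metric is defined, the sketch below computes R² = 1 - SS_res / SS_tot in plain Python. `coefficient_of_determination` and the toy numbers are hypothetical and unrelated to the TPC-H data.

```python
def coefficient_of_determination(labels, predictions):
    """Return R^2 = 1 - SS_res / SS_tot for paired labels and predictions."""
    mean_label = sum(labels) / len(labels)
    # Residual sum of squares: error left over after the model's predictions.
    ss_res = sum((y - p) ** 2 for y, p in zip(labels, predictions))
    # Total sum of squares: error of always predicting the mean label.
    ss_tot = sum((y - mean_label) ** 2 for y in labels)
    return 1.0 - ss_res / ss_tot


toy_labels = [3.0, 5.0, 7.0, 9.0]
toy_predictions = [2.5, 5.5, 6.5, 9.5]

r2 = coefficient_of_determination(toy_labels, toy_predictions)
print(r2)
```

A perfect model scores 1.0, and a model that always predicts the mean label scores 0.0.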

Show predictions.

In [11]:
predictions \
    .select('l_partkey', 'l_suppkey', 'l_quantity', 'l_discount', 'l_tax', 'prediction') \
    .show()

+---------+---------+----------+----------+-----+------------------+
|l_partkey|l_suppkey|l_quantity|l_discount|l_tax|        prediction|
+---------+---------+----------+----------+-----+------------------+
|       18|   500019|     44.00|      0.08| 0.00| 65970.22715696078|
|       38|   750039|     46.00|      0.02| 0.01| 68952.20722547553|
|       74|       75|      1.00|      0.08| 0.05|1471.1811359388616|
|       97|   500098|      8.00|      0.01| 0.03|11960.307478002207|
|      158|   250159|     32.00|      0.00| 0.00| 47914.35790036007|
|      297|   500298|      6.00|      0.07| 0.06|  8998.83993524212|
|      356|   500357|     32.00|      0.02| 0.05|47947.151749741795|
|      380|   750381|     39.00|      0.00| 0.05|58448.125668747765|
|      389|   750390|      7.00|      0.03| 0.08|10493.942591755862|
|      395|   750396|     46.00|      0.10| 0.02| 68999.96446936413|
|      500|   750501|     22.00|      0.07| 0.03|33002.734381948365|
|      523|      524|     48.00|  