In [2]:
from __future__ import print_function

from pyspark.ml import Pipeline
from pyspark.ml.evaluation import RegressionEvaluator
# $example on$
from pyspark.ml.feature import OneHotEncoder, VectorAssembler, StringIndexer
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

if __name__ == "__main__":
    # End-to-end script: load daily price rows from a CSV file, derive an
    # up/down target from consecutive closes, one-hot encode the price
    # columns, fit a decision-tree regressor and print its RMSE on a
    # held-out split.
    #
    # The whole script lives under the guard. Previously only the session
    # and the load were guarded while the remaining statements ran at module
    # level, so on import they could not run because `df` was never defined.
    spark = SparkSession \
        .builder \
        .appName("SVMTest") \
        .getOrCreate()

    # Load the CSV file (first row is the header) as a DataFrame. No schema
    # is supplied and inferSchema is not enabled, so every column is read as
    # a string.
    df = spark.read.format("csv").option("header", "true").load("sample123.csv")

    # Single, unpartitioned window ordered by date; used to reach the
    # previous row's close.
    # NOTE(review): "date" is still a string here, so the ordering is
    # lexicographic -- fine for ISO-style dates, wrong for e.g. M/D/YYYY.
    # Confirm against the data file.
    w = Window.partitionBy().orderBy("date")

    # Cast explicitly before comparing: the columns are strings, and a
    # string-vs-string `<` is lexicographic ("99.5" sorts after "100.2"),
    # which would mislabel the target. The raw columns are left untouched so
    # the encoders below see exactly the same input as before.
    open_price = F.col("open").cast("double")
    close_price = F.col("close").cast("double")
    high_price = F.col("high").cast("double")
    low_price = F.col("low").cast("double")

    df = df.withColumn('diffOpenClose', open_price - close_price)
    df = df.withColumn('diffHighLow', high_price - low_price)
    # 'yes' when the previous close is lower than the current one. The first
    # row has no previous close (lag is null) and therefore falls to 'no'.
    df = df.withColumn(
        'target',
        F.when(F.lag(close_price).over(w) < close_price, 'yes').otherwise('no'))
    # drop() returns a new DataFrame; the result must be rebound or the
    # column is never actually removed.
    df = df.drop('date')

    categoricalColumns = ['high', 'low', 'open', 'close']
    stages = []

    # Index and one-hot encode each price column.
    # NOTE(review): these columns hold continuous prices, so nearly every
    # distinct value becomes its own category (the feature vector is 2624
    # wide in the recorded output). 'close' is also both a feature and the
    # source of the label, and diffOpenClose / diffHighLow are computed but
    # never fed to the assembler. Presumably not intended -- confirm.
    for categoricalCol in categoricalColumns:
        stringIndexer = StringIndexer(inputCol=categoricalCol, outputCol=categoricalCol + 'Index')
        encoder = OneHotEncoder(inputCols=[stringIndexer.getOutputCol()], outputCols=[categoricalCol + "classVec"])
        stages += [stringIndexer, encoder]

    # Map the 'yes'/'no' target onto a numeric 'label' column.
    label_stringIdx = StringIndexer(inputCol='target', outputCol='label')
    stages += [label_stringIdx]

    # Concatenate the one-hot vectors into a single 'features' vector.
    assembler = VectorAssembler(inputCols=[c + "classVec" for c in categoricalColumns], outputCol="features")
    stages += [assembler]

    # The pipeline is fitted on the full DataFrame, i.e. before the
    # train/test split below.
    pipeline = Pipeline(stages=stages)
    pipelineModel = pipeline.fit(df)
    df = pipelineModel.transform(df)

    df.select('close', 'label', 'features').show()
    # 80/20 split; no seed is passed, so the split differs between runs.
    (trainingData, testData) = df.randomSplit([0.8, 0.2])

    # NOTE(review): 'label' is a two-valued indexed column, yet a regressor
    # scored with RMSE is used. A classifier may be what was meant --
    # confirm before relying on the metric.
    dr = DecisionTreeRegressor(labelCol="label", featuresCol="features")

    model = dr.fit(trainingData)
    predictions = model.transform(testData)

    evaluator = RegressionEvaluator(labelCol="label", predictionCol="prediction", metricName="rmse")
    rmse = evaluator.evaluate(predictions)
    print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)


+----------+-----+--------------------+
|     close|label|            features|
+----------+-----+--------------------+
| 611.27002|  1.0|(2624,[26,682,133...|
|459.850006|  1.0|(2624,[164,820,14...|
| 544.22998|  0.0|(2624,[19,675,133...|
|    612.75|  0.0|(2624,[601,1257,1...|
|459.109985|  1.0|(2624,[161,817,14...|
|555.330017|  0.0|(2624,[419,1075,1...|
|529.109985|  1.0|(2624,[63,719,137...|
|     499.5|  1.0|(2624,[254,910,15...|
|547.929993|  0.0|(2624,[393,1049,1...|
|517.900024|  1.0|(2624,[49,705,136...|
|507.859985|  1.0|(2624,[274,930,15...|
|    524.25|  0.0|(2624,[55,711,136...|
|515.580017|  1.0|(2624,[297,953,16...|
|542.119995|  0.0|(2624,[376,1032,1...|
| 600.27002|  0.0|(2624,[2,658,1314...|
|544.650024|  1.0|(2624,[385,1041,1...|
|578.070007|  0.0|(2624,[494,1150,1...|
|491.779999|  1.0|(2624,[230,886,15...|
|    526.25|  0.0|(2624,[0,656,1312...|
|490.200012|  1.0|(2624,[227,883,15...|
+----------+-----+--------------------+
only showing top 20 rows

Root Mean Squa