In [7]:
from pyspark import SparkContext as sc
from pyspark.sql import SparkSession

In [8]:
from pyspark.ml.regression import LinearRegression

In [9]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

from pyspark.ml.feature import StringIndexer


In [10]:
spark1 = SparkSession.builder.appName('StockData').getOrCreate()

In [6]:
df = spark1.read.csv('./all_stocks_5yr.csv',inferSchema=True,header=True)

In [7]:
df.printSchema()

root
 |-- date: string (nullable = true)
 |-- open: double (nullable = true)
 |-- high: double (nullable = true)
 |-- low: double (nullable = true)
 |-- close: double (nullable = true)
 |-- volume: integer (nullable = true)
 |-- Name: string (nullable = true)



## Dataset cleaning

Bellow we check for null values in the dataset and clear them if they exist.

In [8]:
df.createOrReplaceTempView('stock')

In [9]:
nullCriteria = "high IS NULL OR low IS NULL OR volume IS NULL OR date IS NULL OR Name IS NULL";

In [10]:
result = spark1.sql("SELECT * FROM stock WHERE " + nullCriteria)

In [11]:
result.show()

+----------+----+----+----+-------+------+----+
|      date|open|high| low|  close|volume|Name|
+----------+----+----+----+-------+------+----+
|2017-07-26|null|null|null|69.0842|     3| BHF|
|2016-01-12|null|null|null|  88.55|     0| DHR|
|2016-07-01|null|null|null|  49.54|     0| FTV|
|2016-01-12|null|null|null|  52.43|     0|   O|
|2015-06-09|null|null|null| 526.09| 12135|REGN|
|2016-04-07|null|null|null|  41.56|     0|  UA|
|2015-05-12|null|null|null| 124.08|569747|VRTX|
|2015-06-26|null|null|null|   61.9|   100| WRK|
+----------+----+----+----+-------+------+----+



In [12]:
df = df.na.drop()

In [13]:
peakValues = df.groupBy("Name").max().show(100);

+----+---------+---------+--------+----------+-----------+
|Name|max(open)|max(high)|max(low)|max(close)|max(volume)|
+----+---------+---------+--------+----------+-----------+
|ALXN|   206.66|   208.88| 205.509|    207.84|   18836943|
| GIS|    72.65|    72.95|    72.0|     72.64|   19747255|
|   K|     86.9|    87.16|    85.4|     86.98|   11598383|
| LEN|    71.97|    72.17|   71.36|     71.82|   22185910|
|SPGI|    183.8|   185.38| 181.935|     183.8|    9586165|
| AIV|    47.82|    47.91|   47.08|     47.59|    8304107|
| AVY|   121.43|   123.67|   120.8|    122.68|    4154402|
|BF.B|    69.52|  69.9028|   68.63|     69.58|   16152050|
| MMM|   258.51|   259.77|  255.97|    258.63|    9026991|
| PKI|    83.66|    84.49|   81.95|     82.75|    9192085|
| PPG|   120.91| 122.0697|  119.62|    121.47|   12605302|
|  RF|    19.55|     19.9|   19.28|     19.54|   59508341|
| AXP|   102.01|  102.385|  100.33|    101.22|   43783380|
|  CI|   225.76|   227.13|223.7801|    226.22|   1419969

In [14]:
lows = df.groupBy("Name").min().show(100);

+----+---------+---------+--------+----------+-----------+
|Name|min(open)|min(high)|min(low)|min(close)|min(volume)|
+----+---------+---------+--------+----------+-----------+
|ALXN|    82.77|    85.28|   81.82|     83.39|     319770|
| GIS|    42.45|    42.66| 42.4101|      42.6|     950180|
|   K|     57.0|    57.43|   55.69|      56.9|     374751|
| LEN|  30.6667|  31.2255| 30.2941|   30.9118|     714997|
|SPGI|    42.79|    43.96|   42.07|     42.67|     245434|
| AIV|    24.88|    25.14|   24.78|      25.0|     221812|
| AVY|    38.98|    39.33|    38.8|      38.8|     193074|
|BF.B|   32.095|    32.25|   31.92|     32.05|     133754|
| MMM|    102.1|   102.66|  101.75|    101.75|     651007|
| PKI|    29.62|    30.62| 29.5001|     30.35|     159429|
| PPG|   65.185|   65.325|    64.1|     65.04|     371784|
|  RF|     7.18|    7.296|     7.0|      7.08|    3487115|
| AXP|    51.22|    51.59|   50.27|     51.11|     837313|
|  CI|    56.98|    58.07|   55.97|     57.64|     24662

In [15]:
df.columns

['date', 'open', 'high', 'low', 'close', 'volume', 'Name']

In [16]:
from pyspark.sql.functions import expr

df = df.withColumn("dateTimestamp",expr("unix_timestamp(date, 'yyyy-MM-dd')"))
df.show();

+----------+-----+-----+-----+-----+--------+----+-------------+
|      date| open| high|  low|close|  volume|Name|dateTimestamp|
+----------+-----+-----+-----+-----+--------+----+-------------+
|2013-02-08|15.07|15.12|14.63|14.75| 8407500| AAL|   1360274400|
|2013-02-11|14.89|15.01|14.26|14.46| 8882000| AAL|   1360533600|
|2013-02-12|14.45|14.51| 14.1|14.27| 8126000| AAL|   1360620000|
|2013-02-13| 14.3|14.94|14.25|14.66|10259500| AAL|   1360706400|
|2013-02-14|14.94|14.96|13.16|13.99|31879900| AAL|   1360792800|
|2013-02-15|13.93|14.61|13.93| 14.5|15628000| AAL|   1360879200|
|2013-02-19|14.33|14.56|14.08|14.26|11354400| AAL|   1361224800|
|2013-02-20|14.17|14.26|13.15|13.33|14725200| AAL|   1361311200|
|2013-02-21|13.62|13.95| 12.9|13.37|11922100| AAL|   1361397600|
|2013-02-22|13.57| 13.6|13.21|13.57| 6071400| AAL|   1361484000|
|2013-02-25| 13.6|13.76| 13.0|13.02| 7186400| AAL|   1361743200|
|2013-02-26|13.14|13.42| 12.7|13.26| 9419000| AAL|   1361829600|
|2013-02-27|13.28|13.62|1

In [17]:
indexer = StringIndexer(inputCol="Name", outputCol="nameIndex")
df = indexer.fit(df).transform(df)
df.show()

+----------+-----+-----+-----+-----+--------+----+-------------+---------+
|      date| open| high|  low|close|  volume|Name|dateTimestamp|nameIndex|
+----------+-----+-----+-----+-----+--------+----+-------------+---------+
|2013-02-08|15.07|15.12|14.63|14.75| 8407500| AAL|   1360274400|      1.0|
|2013-02-11|14.89|15.01|14.26|14.46| 8882000| AAL|   1360533600|      1.0|
|2013-02-12|14.45|14.51| 14.1|14.27| 8126000| AAL|   1360620000|      1.0|
|2013-02-13| 14.3|14.94|14.25|14.66|10259500| AAL|   1360706400|      1.0|
|2013-02-14|14.94|14.96|13.16|13.99|31879900| AAL|   1360792800|      1.0|
|2013-02-15|13.93|14.61|13.93| 14.5|15628000| AAL|   1360879200|      1.0|
|2013-02-19|14.33|14.56|14.08|14.26|11354400| AAL|   1361224800|      1.0|
|2013-02-20|14.17|14.26|13.15|13.33|14725200| AAL|   1361311200|      1.0|
|2013-02-21|13.62|13.95| 12.9|13.37|11922100| AAL|   1361397600|      1.0|
|2013-02-22|13.57| 13.6|13.21|13.57| 6071400| AAL|   1361484000|      1.0|
|2013-02-25| 13.6|13.76| 

In [18]:
assembler = VectorAssembler(
    inputCols=['dateTimestamp', 'open', 'high', 'low', 'volume', 'nameIndex'],
    outputCol="features")

In [19]:
output = assembler.transform(df)

In [20]:
output.show()

+----------+-----+-----+-----+-----+--------+----+-------------+---------+--------------------+
|      date| open| high|  low|close|  volume|Name|dateTimestamp|nameIndex|            features|
+----------+-----+-----+-----+-----+--------+----+-------------+---------+--------------------+
|2013-02-08|15.07|15.12|14.63|14.75| 8407500| AAL|   1360274400|      1.0|[1.3602744E9,15.0...|
|2013-02-11|14.89|15.01|14.26|14.46| 8882000| AAL|   1360533600|      1.0|[1.3605336E9,14.8...|
|2013-02-12|14.45|14.51| 14.1|14.27| 8126000| AAL|   1360620000|      1.0|[1.36062E9,14.45,...|
|2013-02-13| 14.3|14.94|14.25|14.66|10259500| AAL|   1360706400|      1.0|[1.3607064E9,14.3...|
|2013-02-14|14.94|14.96|13.16|13.99|31879900| AAL|   1360792800|      1.0|[1.3607928E9,14.9...|
|2013-02-15|13.93|14.61|13.93| 14.5|15628000| AAL|   1360879200|      1.0|[1.3608792E9,13.9...|
|2013-02-19|14.33|14.56|14.08|14.26|11354400| AAL|   1361224800|      1.0|[1.3612248E9,14.3...|
|2013-02-20|14.17|14.26|13.15|13.33|1472

In [21]:
final_data = output.select("features",'close')
final_data.show(1000)

+--------------------+-------+
|            features|  close|
+--------------------+-------+
|[1.3602744E9,15.0...|  14.75|
|[1.3605336E9,14.8...|  14.46|
|[1.36062E9,14.45,...|  14.27|
|[1.3607064E9,14.3...|  14.66|
|[1.3607928E9,14.9...|  13.99|
|[1.3608792E9,13.9...|   14.5|
|[1.3612248E9,14.3...|  14.26|
|[1.3613112E9,14.1...|  13.33|
|[1.3613976E9,13.6...|  13.37|
|[1.361484E9,13.57...|  13.57|
|[1.3617432E9,13.6...|  13.02|
|[1.3618296E9,13.1...|  13.26|
|[1.361916E9,13.28...|  13.41|
|[1.3620024E9,13.4...|  13.43|
|[1.3620888E9,13.3...|  13.61|
|[1.362348E9,13.5,...|   13.9|
|[1.3624344E9,14.0...|  14.05|
|[1.3625208E9,14.5...|  14.57|
|[1.3626072E9,14.7...|  14.82|
|[1.3626936E9,14.9...|  14.92|
|[1.3629528E9,14.8...|  15.13|
|[1.3630392E9,15.1...|   15.5|
|[1.3631256E9,15.5...|  15.91|
|[1.363212E9,15.98...|  16.25|
|[1.3632984E9,16.4...|  15.98|
|[1.3635576E9,15.8...|  16.29|
|[1.363644E9,16.48...|  16.78|
|[1.3637304E9,17.1...|  17.23|
|[1.3638168E9,17.2...|   17.0|
|[1.3639

In [22]:
train_data,test_data = final_data.randomSplit([0.7,0.3])
train_data.describe().show()

+-------+-----------------+
|summary|            close|
+-------+-----------------+
|  count|           432728|
|   mean|83.14065450675723|
| stddev|98.08106437414945|
|    min|             1.66|
|    max|           2049.0|
+-------+-----------------+



## Apply Linear Regression to predict future close prices of companies

In [23]:
lr = LinearRegression(labelCol='close')

In [24]:
lrModel = lr.fit(train_data)

In [25]:
print("Coefficients: {} Intercept: {}".format(lrModel.coefficients,lrModel.intercept))

Coefficients: [-2.3759106659457825e-11,-0.5532282403107848,0.7856179207840421,0.7678882612044677,1.9695624141649382e-11,-3.6377929267410807e-06] Intercept: 0.027708083760438634


In [26]:
test_results = lrModel.evaluate(test_data)

In [27]:
test_results.residuals.show()

+--------------------+
|           residuals|
+--------------------+
| -0.1433306946916506|
|-0.02114977456430...|
| 0.02524057071091157|
|-0.01818503165313956|
|-0.01448733697550...|
|0.015135457820235843|
| 0.11648296855627649|
|0.016436076967167423|
| 0.11314145748556825|
|-0.05033073206829641|
| 0.13809761459751968|
|-0.08134854319467877|
|-0.24858304632955708|
| 0.09137594037534313|
| 0.06198530362613042|
|  0.1709423030982009|
| -0.0996231088693662|
| 0.03345435412131792|
|  -1.893748350586975|
|0.042047470299827694|
+--------------------+
only showing top 20 rows



In [28]:
unlabeled_data = test_data.select('features')
unlabeled_data.show()

+--------------------+
|            features|
+--------------------+
|[1.3602744E9,13.2...|
|[1.3602744E9,15.0...|
|[1.3602744E9,15.6...|
|[1.3602744E9,22.6...|
|[1.3602744E9,26.7...|
|[1.3602744E9,35.2...|
|[1.3602744E9,41.2...|
|[1.3602744E9,45.0...|
|[1.3602744E9,46.5...|
|[1.3602744E9,46.6...|
|[1.3602744E9,49.8...|
|[1.3602744E9,77.2...|
|[1.3602744E9,84.1...|
|[1.3602744E9,87.2...|
|[1.3602744E9,127....|
|[1.3602744E9,146....|
|[1.3602744E9,153....|
|[1.3602744E9,163....|
|[1.3602744E9,261....|
|[1.3605336E9,11.7...|
+--------------------+
only showing top 20 rows



In [29]:
predictions = lrModel.transform(unlabeled_data)

In [30]:
predictions.show()

+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[1.3602744E9,13.2...| 13.55333069469165|
|[1.3602744E9,15.0...|14.771149774564307|
|[1.3602744E9,15.6...|15.264759429289088|
|[1.3602744E9,22.6...| 22.69818503165314|
|[1.3602744E9,26.7...| 26.83448733697551|
|[1.3602744E9,35.2...|35.304864542179764|
|[1.3602744E9,41.2...|41.343517031443724|
|[1.3602744E9,45.0...| 45.12356392303283|
|[1.3602744E9,46.5...| 46.77685854251443|
|[1.3602744E9,46.6...|46.783630732068296|
|[1.3602744E9,49.8...| 50.46190238540248|
|[1.3602744E9,77.2...| 76.64134854319468|
|[1.3602744E9,84.1...| 84.69858304632956|
|[1.3602744E9,87.2...| 87.35862405962466|
|[1.3602744E9,127....|128.77801469637387|
|[1.3602744E9,146....| 146.2790576969018|
|[1.3602744E9,153....|154.17962310886938|
|[1.3602744E9,163....|164.40654564587868|
|[1.3602744E9,261....|263.84374835058696|
|[1.3605336E9,11.7...|11.817952529700172|
+--------------------+------------

In [31]:
print("RMSE: {}".format(test_results.rootMeanSquaredError))
print("MSE: {}".format(test_results.meanSquaredError))

RMSE: 0.6919378837733751
MSE: 0.47877803500077665


In [32]:
import time
import datetime
d = datetime.date(2020,6,7)

unixtime = time.mktime(d.timetuple())

In [33]:
# Find index of apple stock

df.createOrReplaceTempView('stock')
spark1.sql("SELECT Name, nameIndex FROM stock WHERE Name = 'AAPL'").show()

+----+---------+
|Name|nameIndex|
+----+---------+
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
|AAPL|      3.0|
+----+---------+
only showing top 20 rows



In [34]:
#Create data frame

predictDf = spark1.createDataFrame(
    [
        (unixtime, 330, 335, 330, 3.0, 34312550), 
    ],
    ['dateTimestamp', 'open', 'high', 'low', 'nameIndex', 'volume'] 
)

In [35]:
pred = assembler.transform(predictDf)
pred.show()

+-------------+----+----+---+---------+--------+--------------------+
|dateTimestamp|open|high|low|nameIndex|  volume|            features|
+-------------+----+----+---+---------+--------+--------------------+
|  1.5914772E9| 330| 335|330|      3.0|34312550|[1.5914772E9,330....|
+-------------+----+----+---+---------+--------+--------------------+



In [36]:
predFeatures = pred.select('features')
predFeatures.show()

+--------------------+
|            features|
+--------------------+
|[1.5914772E9,330....|
+--------------------+



In [37]:
predResult = lrModel.transform(predFeatures)

In [38]:
# Prediction for apple stock price 
predResult.show()

+--------------------+-----------------+
|            features|       prediction|
+--------------------+-----------------+
|[1.5914772E9,330....|334.0103712584984|
+--------------------+-----------------+



## Apply Deep Learning methods to predict future close prices of companies

In [1]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SQLContext
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras import optimizers, regularizers
from pyspark.mllib.evaluation import MulticlassMetrics
from keras.optimizers import Adam
from pyspark.ml import Pipeline 

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
from pyspark.sql.functions import expr
from pyspark.sql.functions import format_number,dayofmonth,hour,dayofyear,month,year,weekofyear,date_format

In [3]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

from pyspark.ml.feature import StringIndexer


In [4]:
# Spark Session
conf = SparkConf().setAppName('Spark DL Tabular Pipeline').setMaster('local[6]')
sc = SparkContext(conf=conf)
sql_context = SQLContext(sc)

In [5]:
df = sql_context.read.csv('./all_stocks_5yr.csv',inferSchema=True,header=True)

In [6]:
df.printSchema()

root
 |-- date: string (nullable = true)
 |-- open: double (nullable = true)
 |-- high: double (nullable = true)
 |-- low: double (nullable = true)
 |-- close: double (nullable = true)
 |-- volume: integer (nullable = true)
 |-- Name: string (nullable = true)



In [7]:
df = df.withColumn("dateTimestamp",expr("unix_timestamp(date, 'yyyy-MM-dd')"))

In [8]:
df = df.na.drop()

In [9]:
indexer = StringIndexer(inputCol="Name", outputCol="nameIndex")
df = indexer.fit(df).transform(df)
df = df.drop("name")

df.show()

+----------+-----+-----+-----+-----+--------+-------------+---------+
|      date| open| high|  low|close|  volume|dateTimestamp|nameIndex|
+----------+-----+-----+-----+-----+--------+-------------+---------+
|2013-02-08|15.07|15.12|14.63|14.75| 8407500|   1360274400|      1.0|
|2013-02-11|14.89|15.01|14.26|14.46| 8882000|   1360533600|      1.0|
|2013-02-12|14.45|14.51| 14.1|14.27| 8126000|   1360620000|      1.0|
|2013-02-13| 14.3|14.94|14.25|14.66|10259500|   1360706400|      1.0|
|2013-02-14|14.94|14.96|13.16|13.99|31879900|   1360792800|      1.0|
|2013-02-15|13.93|14.61|13.93| 14.5|15628000|   1360879200|      1.0|
|2013-02-19|14.33|14.56|14.08|14.26|11354400|   1361224800|      1.0|
|2013-02-20|14.17|14.26|13.15|13.33|14725200|   1361311200|      1.0|
|2013-02-21|13.62|13.95| 12.9|13.37|11922100|   1361397600|      1.0|
|2013-02-22|13.57| 13.6|13.21|13.57| 6071400|   1361484000|      1.0|
|2013-02-25| 13.6|13.76| 13.0|13.02| 7186400|   1361743200|      1.0|
|2013-02-26|13.14|13

In [10]:
assembler = VectorAssembler(
    inputCols=['dateTimestamp', 'open', 'close', 'high', 'low', 'volume', 'nameIndex'],
    outputCol="features")

In [11]:
def growth(close, open):
    sign = -1;
    a = close;
    b = open;
  
    if (close > open):
        sign = 1;
        a = open
        b = close

    diff = b-a;
    percentage = diff / close;
    
    # 3 = big move, 2 = medium move, 1 = small move

    if (percentage <= 0.02):
        return 1 * sign
    if (percentage <= 0.05):
        return 2 * sign
    return 3 * sign;

In [12]:
from pyspark.sql.functions import udf
udf_compute_growth = udf(growth)
df = df.withColumn("label",udf_compute_growth(df["close"], df["open"]))
df.show()

+----------+-----+-----+-----+-----+--------+-------------+---------+-----+
|      date| open| high|  low|close|  volume|dateTimestamp|nameIndex|label|
+----------+-----+-----+-----+-----+--------+-------------+---------+-----+
|2013-02-08|15.07|15.12|14.63|14.75| 8407500|   1360274400|      1.0|   -2|
|2013-02-11|14.89|15.01|14.26|14.46| 8882000|   1360533600|      1.0|   -2|
|2013-02-12|14.45|14.51| 14.1|14.27| 8126000|   1360620000|      1.0|   -1|
|2013-02-13| 14.3|14.94|14.25|14.66|10259500|   1360706400|      1.0|    2|
|2013-02-14|14.94|14.96|13.16|13.99|31879900|   1360792800|      1.0|   -3|
|2013-02-15|13.93|14.61|13.93| 14.5|15628000|   1360879200|      1.0|    2|
|2013-02-19|14.33|14.56|14.08|14.26|11354400|   1361224800|      1.0|   -1|
|2013-02-20|14.17|14.26|13.15|13.33|14725200|   1361311200|      1.0|   -3|
|2013-02-21|13.62|13.95| 12.9|13.37|11922100|   1361397600|      1.0|   -1|
|2013-02-22|13.57| 13.6|13.21|13.57| 6071400|   1361484000|      1.0|   -1|
|2013-02-25|

In [13]:
assembledData = assembler.transform(df)
assembledData.show()

+----------+-----+-----+-----+-----+--------+-------------+---------+-----+--------------------+
|      date| open| high|  low|close|  volume|dateTimestamp|nameIndex|label|            features|
+----------+-----+-----+-----+-----+--------+-------------+---------+-----+--------------------+
|2013-02-08|15.07|15.12|14.63|14.75| 8407500|   1360274400|      1.0|   -2|[1.3602744E9,15.0...|
|2013-02-11|14.89|15.01|14.26|14.46| 8882000|   1360533600|      1.0|   -2|[1.3605336E9,14.8...|
|2013-02-12|14.45|14.51| 14.1|14.27| 8126000|   1360620000|      1.0|   -1|[1.36062E9,14.45,...|
|2013-02-13| 14.3|14.94|14.25|14.66|10259500|   1360706400|      1.0|    2|[1.3607064E9,14.3...|
|2013-02-14|14.94|14.96|13.16|13.99|31879900|   1360792800|      1.0|   -3|[1.3607928E9,14.9...|
|2013-02-15|13.93|14.61|13.93| 14.5|15628000|   1360879200|      1.0|    2|[1.3608792E9,13.9...|
|2013-02-19|14.33|14.56|14.08|14.26|11354400|   1361224800|      1.0|   -1|[1.3612248E9,14.3...|
|2013-02-20|14.17|14.26|13.15|

In [14]:
final_data = assembledData.select("features",'label')

In [15]:
train_data,test_data = final_data.randomSplit([0.7,0.3])
train_data.describe().show()

+-------+--------------------+
|summary|               label|
+-------+--------------------+
|  count|              433583|
|   mean|0.025312339275294464|
| stddev|  1.1555130476474396|
|    min|                  -1|
|    max|                   3|
+-------+--------------------+



In [16]:
nr_classes = train_data.select("label").distinct().count();

In [17]:
input_len = len(train_data.select("features").first()[0])

In [18]:
dlModel = Sequential()




In [19]:
dlModel.add(Dense(128, input_shape=(input_len,), activity_regularizer=regularizers.l2(0.01)))





In [20]:
dlModel.add(Activation('relu'))

In [21]:
dlModel.add(Dropout(rate=0.3))


Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.


In [22]:
dlModel.add(Dense(128, activity_regularizer=regularizers.l2(0.01)))

In [23]:
dlModel.add(Activation('relu'))

In [24]:
dlModel.add(Dropout(rate=0.3))

In [25]:
dlModel.add(Dense(nr_classes))

In [26]:
dlModel.add(Activation("sigmoid"))

In [27]:
dlModel.compile(loss="binary_crossentropy", optimizer="adam")



Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [41]:
dlModel.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 128)               1024      
_________________________________________________________________
activation_1 (Activation)    (None, 128)               0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 128)               16512     
_________________________________________________________________
activation_2 (Activation)    (None, 128)               0         
_________________________________________________________________
dropout_2 (Dropout)          (None, 128)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 6)                 774       
__________

In [42]:
optimizer_conf = optimizers.Adam(lr=0.01)

In [43]:
opt_conf = optimizers.serialize(optimizer_conf)

In [44]:
from elephas.ml_model import ElephasEstimator

In [45]:
estimator = ElephasEstimator()

In [46]:
#estimator.setFeaturesCol("features")
#estimator.setLabelCol("growth")
estimator.set_keras_model_config(dlModel.to_yaml())
estimator.set_categorical_labels(True)
estimator.set_nb_classes(nr_classes)
estimator.set_num_workers(2)
estimator.set_epochs(65)
estimator.set_batch_size(64)
estimator.set_verbosity(1)
estimator.set_validation_split(0.10)
estimator.set_optimizer_config(opt_conf)
estimator.set_mode("synchronous")
estimator.set_loss("binary_crossentropy")
estimator.set_metrics(["acc"])

ElephasEstimator_01ecc13126b5

In [47]:
dl_pipeline = Pipeline(stages=[estimator])

In [48]:
def dl_pipeline_fit_score_results(dl_pipeline=dl_pipeline,
                                  train_data=train_data,
                                  test_data=test_data,
                                  label='label'):
    
    fit_dl_pipeline = dl_pipeline.fit(train_data)
    pred_train = fit_dl_pipeline.transform(train_data)
    pred_test = fit_dl_pipeline.transform(test_data)
    
    pnl_train = pred_train.select(label, "prediction")
    pnl_test = pred_test.select(label, "prediction")
    
    pred_and_label_train = pnl_train.rdd.map(lambda row: (row[label], row['prediction']))
    pred_and_label_test = pnl_test.rdd.map(lambda row: (row[label], row['prediction']))
    
    metrics_train = MulticlassMetrics(pred_and_label_train)
    metrics_test = MulticlassMetrics(pred_and_label_test)
    
    print("Training Data Accuracy: {}".format(round(metrics_train.accuracy,4)))
    print("Training Data Confusion Matrix")
    display(pnl_train.crosstab('label', 'prediction').toPandas())
    
    print("\nTest Data Accuracy: {}".format(round(metrics_test.accuracy, 4)))
    print("Test Data Confusion Matrix")
    display(pnl_test.crosstab('label', 'prediction').toPandas())

In [49]:
dl_pipeline_fit_score_results(dl_pipeline=dl_pipeline,
                              train_data=train_data,
                              test_data=test_data,
                              label='label');

  config = yaml.load(yaml_string)


>>> Fit model
>>> Synchronous training complete.
Training Data Accuracy: 0.0
Training Data Confusion Matrix


Unnamed: 0,label_prediction,0.0
0,-3.0,1610
1,1.0,202262
2,2.0,20040
3,3.0,1101
4,-1.0,187300
5,-2.0,21270



Test Data Accuracy: 0.0
Test Data Confusion Matrix


Unnamed: 0,label_prediction,0.0
0,-3.0,698
1,1.0,86515
2,2.0,8577
3,3.0,475
4,-1.0,80268
5,-2.0,8913
