# Create the full feature vector dataframe 
## Optimised to broadcast joins and trying to avoid data spills to disk in the executiors

### First! we set up the Spark Context

In [1]:
# Set the new configuration
conf = SparkConf().setAll([('spark.executor.memory', '8g'),\
                           ('spark.driver.memory', '10g'),\
                           ('spark.driver.maxResultSize', 0), \
                           ('spark.shuffle.service.enabled', True), \
                           ('spark.dynamicAllocation.enabled', True), \
                           #('spark.executor.instances', 50), \
                           ('spark.dynamicAllocation.executorIdleTimeout', 600), \
                           ('spark.sql.autoBroadcastJoinThreshold', 52428800), \
                           ('spark.executor.cores', 4),\
                           ('spark.default.parallelism', 90),\
                           ('spark.executor.memoryOverhead', '4g'),\
                           ('spark.driver.memoryOverhead', '4g'),\
                           ('spark.scheduler.mode', 'FAIR'),\
                           ('spark.kryoserializer.buffer.max', '512m'),\
                           ('spark.app.name','FullVectorTesting - JupyterHub version')])# Show the current options




#                           ('spark.dynamicAllocation.maxExecutors', 90), \


# Stop the old context
sc.stop()

# And restart the context with the new configuration
sc = SparkContext(conf=conf)
sqlContext = SQLContext(sc)

In [2]:
import matplotlib
matplotlib.use('Agg')
%matplotlib inline

import matplotlib.pyplot as plt
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['xtick.labelsize'] = 12
plt.rcParams['ytick.labelsize'] = 12

In [3]:
import os
import os.path as osp
#import commands
import time
import random

import numpy as np

import numpy as np
from pyspark import SparkConf,SparkContext, StorageLevel
from pyspark.sql import Row, SQLContext, SparkSession
from pyspark.sql.functions import monotonically_increasing_id
from pyspark.sql.types import *
from pyspark.ml.linalg import Vectors


from datetime import datetime
LogFile=datetime.now().strftime('Program.log')

import logging
logger = logging.getLogger('myapp')
hdlr = logging.FileHandler(LogFile)
formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
hdlr.setFormatter(formatter)
logger.addHandler(hdlr)
logger.setLevel(logging.INFO)

In [4]:
import gc # manual garbag collection to stop leaks on Collect() gc.collect()
import pandas as pd

In [5]:
#from keras.layers import *
#from keras.models import Model, load_model
#from keras.optimizers import Adam, Nadam, SGD
#from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard
from keras.utils import to_categorical
#from keras.preprocessing.sequence import pad_sequences

Using TensorFlow backend.


In [6]:
pgm_start=time.time()
pgm_startCpu=time.clock()

In [7]:
classes = np.array([6, 15, 16, 42, 52, 53, 62, 64, 65, 67, 88, 90, 92, 95, 99], dtype='int32')
class_names = ['class_6','class_15','class_16','class_42','class_52','class_53','class_62','class_64','class_65','class_67','class_88','class_90','class_92','class_95','class_99']
class_weight = {6: 1, 15: 2, 16: 1, 42: 1, 52: 1, 53: 1, 62: 1, 64: 2, 65: 1, 67: 1, 88: 1, 90: 1, 92: 1, 95: 1, 99: 1}

# LSST passbands (nm)  u    g    r    i    z    y      
passbands = np.array([357, 477, 621, 754, 871, 1004], dtype='float32')

In [8]:
augment_count = 35
batch_size = 1000
batch_size2 = 5000
optimizer = 'nadam'
num_models = 1
use_specz = False
valid_size = 0.1
max_epochs = 1000

limit = 1000000
sequence_len = 256

In [9]:
sqlContext = SQLContext(sc)

In [10]:
sqlContext.sql("use plasticc")

DataFrame[]

In [11]:
## Get the augmented vector training set
vectorTable="training_set_vectors"
trainingVectorsDF=sqlContext.sql("select * from {}".format(vectorTable)).persist()

## get some data on the dataframe

In [12]:
trainingVectorsDF.printSchema()

root
 |-- object_id: integer (nullable = true)
 |-- meta: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- target: integer (nullable = true)
 |-- specz: double (nullable = true)
 |-- band: array (nullable = true)
 |    |-- element: array (containsNull = true)
 |    |    |-- element: double (containsNull = true)
 |-- hist: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- interval: array (nullable = true)
 |    |    |    |-- element: array (containsNull = true)
 |    |    |    |    |-- element: double (containsNull = true)
 |    |    |-- deltaMjd: array (nullable = true)
 |    |    |    |-- element: array (containsNull = true)
 |    |    |    |    |-- element: double (containsNull = true)
 |    |    |-- rval: array (nullable = true)
 |    |    |    |-- element: array (containsNull = true)
 |    |    |    |    |-- element: double (containsNull = true)
 |    |    |-- flux: array (nullable = true)
 |    |    |    |-- element:

In [13]:
trainingVectorsDF.explain()

== Physical Plan ==
InMemoryTableScan [object_id#0, meta#1, target#2, specz#3, band#4, hist#5]
   +- InMemoryRelation [object_id#0, meta#1, target#2, specz#3, band#4, hist#5], true, 10000, StorageLevel(disk, memory, 1 replicas)
         +- *(1) FileScan parquet plasticc.training_set_vectors[object_id#0,meta#1,target#2,specz#3,band#4,hist#5] Batched: false, Format: Parquet, Location: InMemoryFileIndex[hdfs://athena-1.nimbus.pawsey.org.au:8020/user/hive/warehouse/plasticc.db/train..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<object_id:int,meta:array<double>,target:int,specz:double,band:array<array<double>>,hist:ar...


## Set up the training, test and validation splits
We'll do this on the metadata table as we use this table for the master joins

In [14]:
weights = [.8, .1, .1]
seed = 42 # seed=0L
train_df, validation_df, test_df = trainingVectorsDF.randomSplit(weights, seed)

#### Between here
We have dictionary testing. May not actually be any faster, because we still have to do the collect() but it may be only once collect, and not five

In [15]:
list_objects = map(lambda row: row.asDict(), train_df.collect())
#


In [16]:
object_vectors = {object['object_id']: object for object in list_objects}

In [17]:
for key in object_vectors.keys():
    print(key)


7315
37149
38754
62230
63718
73610
80155
81252
81665
88980
99452
103145
110958
112462
115792
122716
123437
138415
151458
151462
160048
169678
178664
189164
190303
194669
198690
203203
207413
213707
215463
220893
225461
231097
234657
235141
235402
245610
252924
253633
271554
275128
285912
286720
1658030
4063882
4625732
6273384
7086564
7390884
8667146
8998750
9754761
17559797
18703334
19370388
19402817
21386867
21400791
21816635
21932332
22393703
22408171
24415853
25517472
27570837
29331377
31294384
31529185
31965152
32237771
33463014
34001553
34276621
34971934
35689739
35730975
37077473
37268602
38396910
38941537
42157427
42364797
44982484
45308850
46609480
47894890
50935885
52400071
53540611
54793663
56949500
58412884
60983185
62662302
64264583
65114256
66586389
66743070
68699474
70078582
72361403
72568232
75011768
76485976
78889492
81531508
83370784
83894218
85568314
85779051
86068216
86758363
87417982
87771300
90321306
90325864
92053866
93971222
94098306
101326940
103081913
105502451

56837447
58527015
58635449
60478784
61134104
61668962
64054152
66679789
67108107
68270337
70055532
70764218
74193504
78542168
82196474
82956932
83994919
87082980
88867245
88949087
90880969
92529037
94935328
97072448
100826219
101806291
101955355
102520034
103310781
103943330
104122586
104719061
106262435
106438996
108374759
108788660
108857559
109083977
112424999
113204830
117649911
121207692
123056998
124331722
125238759
128698690
129257128
23396
25039
33088
88587
108358
177557
197776
201803
215159
218627
222059
228629
252115
291304
294782
316866
319255
324898
331844
335815
5799543
5823700
6026236
6125178
6788861
7536832
7713926
8247543
11492513
12190308
14193352
14705653
17183361
18795072
19887645
20135303
20148786
23039398
23592539
23918800
25158157
28920225
30559210
36512644
36689698
38092995
39026097
39389857
42638097
43537711
45573017
46026984
48734543
49686525
50979853
51872509
56002762
56241450
59625761
59851245
62622592
63142088
64584558
65978932
70603566
73196956
73557446
741

99158421
102389747
107760699
108481380
112007948
112220827
114421739
118657889
119235488
120676520
121255943
124687561
125572934
126019119
127897121
128490634
129964738
130088373
130739978
4088
10478
16983
24193
45349
59732
60407
70430
79515
91335
133513
137645
188405
190878
246511
248046
254314
272926
273526
302689
313565
315664
316154
325346
337060
1497514
4051758
5174610
7484291
7609754
11417363
14329136
18394375
19451710
19496909
19850495
21007001
23730608
24915102
28978765
29518677
33061874
33476053
35935438
41562706
42485979
43645967
47069739
47942566
48679089
50941150
51220364
51859949
55427870
58112745
63039137
67406546
67691418
68143849
70819205
70947093
74961768
74994691
78242149
80863915
81097702
81787666
82160413
82943308
85030920
86742908
89745225
90648418
93303567
95810605
103817167
104099126
106520041
109694858
109786118
111401707
112163086
116941623
117846457
118230321
119220981
119911135
120016736
122640408
126686045
127727509
130695262
39223
52740
68003
94004
103100
1

83423091
84224194
85713794
86637920
90813331
91900996
92115051
92364022
94410053
94785585
96728510
97035078
97639845
98118937
111941650
116764335
118760315
118926278
123049303
124781998
125554681
128932972
128949060
128961480
130408188
3041
37661
55060
62078
80780
97406
110387
111448
145257
166697
175824
185839
224100
225829
310353
325290
335624
1426132
6460481
7079540
8476473
8698141
9547880
10068587
10120851
13698236
14119494
15600667
19020589
20767388
23106151
23897466
24960613
27410622
28532408
29559631
29676720
31732480
32654684
41056547
41284720
41388767
49308793
56942944
58006418
60292952
60641958
61665356
63210303
63677512
63729957
65925566
66000729
66209070
67679513
68339955
69285064
69659448
71171595
76454672
76475797
86258843
88451685
89299828
91177254
91932161
95450752
96298291
96975445
97011617
98714884
99021575
102904396
104115170
104321005
107732897
110107673
111660438
118970396
119209616
120060791
125889687
127145862
129329222
129391307
4819
8784
17094
28301
37776
41515

73569281
78186028
78617001
79358538
85732072
92045511
92263953
92321294
94066636
95404147
95866530
96914178
99836793
102354931
102700077
103802998
104001614
105327965
111327134
111575197
113044327
114090538
119879213
120526488
123275660
127843563
128895075
19213
29670
30191
34166
34299
60376
61165
95864
114341
123211
139362
145107
158241
168465
175409
177211
204379
209911
225529
229855
237559
239364
241329
247729
256581
267911
284093
301369
317884
332603
4092696
4190006
7285136
10803573
11471117
13465284
15091832
15926229
17787345
18217408
23734911
29074185
29587925
31930691
32515689
34875805
36498915
37956124
39171778
40271287
44885482
45854551
50150449
51053939
54732090
56369732
59206712
60448244
61014952
69542885
73402968
75498290
78480403
80730786
81274374
92253169
97111379
103428710
104711764
106363739
106793212
108275496
114499870
117633376
117700715
118381916
118606724
121211169
123038520
123434604
126626855
127884877
15251
17172
39846
46804
70816
73509
107615
161521
166956
1855

#### And here ^^^^^

In [18]:
idArr=np.array(train_df.select('object_id').collect(), dtype='float32')

In [19]:
r,c=idArr.shape
idArr.reshape(r,)
meta_len=10

In [20]:
metaArr=np.array(train_df.select('meta').collect(), dtype='float32').reshape(r,meta_len)
bandArr= np.array(train_df.select('band').collect() , dtype='int32').reshape(r,sequence_len)

histArray=np.zeros((r,sequence_len,8), dtype='float32') 
# this will work brilliantly as get_keras_data sets three columns to zeros anyway

#mjdInt=np.array(vectors_df.select('hist.interval').collect(), dtype='float32').reshape(r,sequence_len)
deltaMjd=np.array(train_df.select('hist.deltaMjd').collect(), dtype='float32').reshape(r,sequence_len)
rval=np.array(train_df.select('hist.rval').collect(), dtype='float32').reshape(r,sequence_len)
fluxTest=np.array(train_df.select('hist.flux').collect(), dtype='float32').reshape(r,sequence_len)
flux_err_test=np.array(train_df.select('hist.flux_err').collect(), dtype='float32').reshape(r,sequence_len)
#detected=np.array(train_df.select('hist.detected').collect(), dtype='float32').reshape(r,sequence_len)
source_wavelength=np.array(train_df.select('hist.source_wavelength').collect(), dtype='float32').reshape(r,sequence_len)
#received_wavelength=np.array(train_df.select('hist.received_wavelength').collect(), dtype='float32').reshape(r,sequence_len)


#### as per the baseline program, we remove the abs time, detected and receoved_wavelength data

In [21]:

#histArray[:,:,0]=mjdInt
histArray[:,:,1]=fluxTest
histArray[:,:,2]=flux_err_test
#histArray[:,:,3]=detected
histArray[:,:,4]=deltaMjd
histArray[:,:,5]=rval
histArray[:,:,6]=source_wavelength
#histArray[:,:,7]=received_wavelength

In [22]:
# Create the final vector dictionary
X = {
        'id': idArr,
        'meta': metaArr,
        'band': bandArr,
        'hist': histArray
    }

In [23]:
Y = to_categorical(np.array(train_df.select('target').collect(), dtype='int32'), num_classes=len(classes))

In [None]:
pgm_elapsed=time.time() - pgm_start
pgm_elapsedCpu=time.clock() - pgm_startCpu

In [None]:
print(pgm_elapsed)
print(pgm_elapsedCpu)

# Model testing

Basic Elephas!

In [32]:
from keras.layers import *
from keras.models import Model, load_model

In [26]:
hist_input = Input(shape=X['hist'][0].shape, name='hist')

W0924 23:52:13.599566 140136970643200 deprecation_wrapper.py:119] From /home/hduser/.virtualenvs/Elephas/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.



In [27]:
hist_input

<tf.Tensor 'hist:0' shape=(?, 256, 8) dtype=float32>

In [28]:
meta_input = Input(shape=X['meta'][0].shape, name='meta')
band_input = Input(shape=X['band'][0].shape, name='band')


In [29]:
band_emb = Embedding(8, 8)(band_input)

hist = concatenate([hist_input, band_emb])
hist = TimeDistributed(Dense(40, activation='relu'))(hist)


W0924 23:53:51.583766 140136970643200 deprecation_wrapper.py:119] From /home/hduser/.virtualenvs/Elephas/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0924 23:53:51.604385 140136970643200 deprecation_wrapper.py:119] From /home/hduser/.virtualenvs/Elephas/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.



In [33]:
rnn = Bidirectional(GRU(80, return_sequences=True))(hist)
rnn = SpatialDropout1D(0.5)(rnn)

gmp = GlobalMaxPool1D()(rnn)
gmp = Dropout(0.5)(gmp)

x = concatenate([meta_input, gmp])
x = Dense(128, activation='relu')(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.5)(x)

output = Dense(15, activation='softmax')(x)

model = Model(inputs=[hist_input, meta_input, band_input], outputs=output)


In [40]:
 model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

W0925 00:09:30.809822 140136970643200 deprecation_wrapper.py:119] From /home/hduser/.virtualenvs/Elephas/lib/python3.6/site-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.

W0925 00:09:30.858464 140136970643200 deprecation_wrapper.py:119] From /home/hduser/.virtualenvs/Elephas/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:3295: The name tf.log is deprecated. Please use tf.math.log instead.



In [34]:
Y

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 1., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.]], dtype=float32)

In [35]:
from elephas.spark_model import SparkModel
from elephas.utils.rdd_utils import to_simple_rdd




In [37]:
rdd = to_simple_rdd(sc, X, Y)

In [38]:
rdd.collect()

[('id', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        dtype=float32)),
 ('meta', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        dtype=float32)),
 ('band', array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        dtype=float32)),
 ('hist', array([0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        dtype=float32))]

In [41]:
spark_model = SparkModel(model, frequency='epoch', mode='asynchronous')

In [42]:
spark_model.fit(rdd, epochs=50, batch_size=32, verbose=2, validation_split=0.1)

>>> Fit model
 * Serving Flask app "elephas.parameter.server" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


I0925 00:10:57.011522 140136970643200 _internal.py:122]  * Running on http://192.168.1.31:4000/ (Press CTRL+C to quit)


>>> Initialize workers
>>> Distribute load


I0925 00:11:35.016265 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:35.028871 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:35.090870 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:35.101553 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:35.112890 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:35.124606 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:35.141435 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:35.150116 140136970643200 _inte

I0925 00:11:35.691044 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:35.706017 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:35.715635 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:35.721364 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:35.728691 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:35.743804 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:35.754377 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:35] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:35.759623 140136970643200 _interna

I0925 00:11:36.313020 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:36] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:36.321848 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:36] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:36.326326 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:36] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:36.335960 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:36] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:36.350026 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:36] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:36.360415 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:36] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:36.366075 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:36] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:36.372603 140136970643200 _interna

I0925 00:11:36.892786 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:36] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:36.902547 140136970643200 _internal.py:122] 192.168.1.42 - - [25/Sep/2019 00:11:36] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:36.907654 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:36] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:36.917544 140136970643200 _internal.py:122] 192.168.1.42 - - [25/Sep/2019 00:11:36] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:36.930564 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:36] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:36.944579 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:36] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:36.964418 140136970643200 _internal.py:122] 192.168.1.42 - - [25/Sep/2019 00:11:36] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:36.974369 140136970643200 _interna

I0925 00:11:37.511696 140136970643200 _internal.py:122] 192.168.1.15 - - [25/Sep/2019 00:11:37] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:37.520089 140136970643200 _internal.py:122] 192.168.1.9 - - [25/Sep/2019 00:11:37] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:37.531562 140136970643200 _internal.py:122] 192.168.1.42 - - [25/Sep/2019 00:11:37] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:37.540886 140136970643200 _internal.py:122] 192.168.1.9 - - [25/Sep/2019 00:11:37] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:37.546993 140136970643200 _internal.py:122] 192.168.1.42 - - [25/Sep/2019 00:11:37] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:37.557978 140136970643200 _internal.py:122] 192.168.1.9 - - [25/Sep/2019 00:11:37] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:37.572556 140136970643200 _internal.py:122] 192.168.1.42 - - [25/Sep/2019 00:11:37] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:37.580922 140136970643200 _internal.p

I0925 00:11:38.100516 140136970643200 _internal.py:122] 192.168.1.9 - - [25/Sep/2019 00:11:38] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:38.135693 140136970643200 _internal.py:122] 192.168.1.42 - - [25/Sep/2019 00:11:38] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:38.148846 140136970643200 _internal.py:122] 192.168.1.9 - - [25/Sep/2019 00:11:38] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:38.154732 140136970643200 _internal.py:122] 192.168.1.42 - - [25/Sep/2019 00:11:38] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:38.163147 140136970643200 _internal.py:122] 192.168.1.9 - - [25/Sep/2019 00:11:38] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:38.178677 140136970643200 _internal.py:122] 192.168.1.42 - - [25/Sep/2019 00:11:38] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:38.189018 140136970643200 _internal.py:122] 192.168.1.9 - - [25/Sep/2019 00:11:38] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:38.194770 140136970643200 _internal.py

I0925 00:11:38.732466 140136970643200 _internal.py:122] 192.168.1.42 - - [25/Sep/2019 00:11:38] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:38.740836 140136970643200 _internal.py:122] 192.168.1.9 - - [25/Sep/2019 00:11:38] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:38.749275 140136970643200 _internal.py:122] 192.168.1.42 - - [25/Sep/2019 00:11:38] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:38.756978 140136970643200 _internal.py:122] 192.168.1.9 - - [25/Sep/2019 00:11:38] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:38.773954 140136970643200 _internal.py:122] 192.168.1.42 - - [25/Sep/2019 00:11:38] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:38.784068 140136970643200 _internal.py:122] 192.168.1.9 - - [25/Sep/2019 00:11:38] "[37mPOST /update HTTP/1.1[0m" 200 -
I0925 00:11:38.793789 140136970643200 _internal.py:122] 192.168.1.9 - - [25/Sep/2019 00:11:38] "[37mGET /parameters HTTP/1.1[0m" 200 -
I0925 00:11:38.816409 140136970643200 _internal.py

>>> Async training complete.


In [44]:
score = spark_model.master_network.evaluate(X, Y, verbose=2)
print('Test accuracy:', score[1])

Test accuracy: 0.12281259944152195


# Model testing
Elephas ML pipeline!

In [154]:
train_df=sqlContext.sql("select * from training_set_flat_augmented_vectors")

In [155]:
train_df.show(10)

+---------+--------------------+------+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|object_id|                meta|target| specz|                band|            interval|            deltaMjd|                rval|                flux|            flux_err|            detected|   source_wavelength| received_wavelength|
+---------+--------------------+------+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|    23409|[0.0, 0.0, 0.0, 0...|     4|0.1407|[0.0, 2.0, 5.0, 3...|[5.99590000000171...|[0.99199999999837...|[8.18419999999605...|[59876.0231, 6020...|[2.181061, 1.4428...|[0.0, 0.0, 0.0, 0...|[313.377808988764...|[357.0, 621.0, 10...|
|    63718|[0.0, 0.0, 0.0, 0...|    11|0.2891|[0.0, 1.0,

In [156]:
train_df.printSchema()

root
 |-- object_id: integer (nullable = true)
 |-- meta: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- target: integer (nullable = true)
 |-- specz: double (nullable = true)
 |-- band: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- interval: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- deltaMjd: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- rval: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- flux: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- flux_err: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- detected: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- source_wavelength: array (nullable = true)
 |    |-- element: double (containsNull = true)
 |-- received_wavelength: array (nullable = true)
 |    |-- element: double (containsNull = true)



In [48]:
from pyspark.sql.functions import udf

In [49]:
from pyspark.ml.linalg import Vectors, VectorUDT

In [50]:
to_vector = udf(lambda a: Vectors.dense(a), VectorUDT())

In [133]:
train_df.select("target","meta","band","interval","rval").show(5)

+------+--------------------+--------------------+--------------------+--------------------+
|target|                meta|                band|            interval|                rval|
+------+--------------------+--------------------+--------------------+--------------------+
|     4|[0.0, 0.0, 0.0, 0...|[0.0, 2.0, 5.0, 3...|[5.99590000000171...|[8.18419999999605...|
|    11|[0.0, 0.0, 0.0, 0...|[0.0, 1.0, 5.0, 0...|[356.232700000000...|[121.828099999998...|
|    11|[0.0, 0.0, 0.0, 0...|[5.0, 5.0, 3.0, 0...|[0.0, 715.9913999...|[0.0, 478.8455000...|
|     6|[0.0, 0.0, 0.0, 0...|[3.0, 4.0, 3.0, 5...|[269.152800000003...|[47.0485999999946...|
|     3|[0.0, 0.0, 0.0, 0...|[3.0, 2.0, 3.0, 2...|[327.106300000006...|[75.1097999999983...|
+------+--------------------+--------------------+--------------------+--------------------+
only showing top 5 rows



In [157]:
train_df = train_df.select("object_id", "target",\
                       to_vector("meta").alias("metaV"),\
                       to_vector("band").alias("bandV"),\
                       to_vector("interval").alias("intV"),\
                       to_vector("deltaMjd").alias("deltaV"),\
                       to_vector("rval").alias("rvalV"),\
                       to_vector("flux").alias("fluxV"),\
                       to_vector("flux_err").alias("flux_errV"),\
                       to_vector("source_wavelength").alias("source_wavelengthV"),\
                       to_vector("received_wavelength").alias("received_wavelengthV")
                      )

In [158]:
train_df.printSchema()

root
 |-- object_id: integer (nullable = true)
 |-- target: integer (nullable = true)
 |-- metaV: vector (nullable = true)
 |-- bandV: vector (nullable = true)
 |-- intV: vector (nullable = true)
 |-- deltaV: vector (nullable = true)
 |-- rvalV: vector (nullable = true)
 |-- fluxV: vector (nullable = true)
 |-- flux_errV: vector (nullable = true)
 |-- source_wavelengthV: vector (nullable = true)
 |-- received_wavelengthV: vector (nullable = true)



In [138]:
train_df.show(10)

+---------+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|object_id|target|               metaV|               bandV|                intV|              deltaV|               rvalV|               fluxV|           flux_errV|  source_wavelengthV|received_wavelengthV|
+---------+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|    23409|     4|[0.0,0.0,0.0,0.0,...|[0.0,2.0,5.0,3.0,...|[5.99590000000171...|[0.99199999999837...|[8.18419999999605...|[59876.0231,60208...|[2.181061,1.44283...|[313.377808988764...|[357.0,621.0,1004...|
|    63718|    11|[0.0,0.0,0.0,0.0,...|[0.0,1.0,5.0,0.0,...|[356.232700000000...|[4.24030000000493...|[121.828099999998...|[60233.2652,60237...|[1.748593,0.85526...|[25

In [159]:
from pyspark.ml.feature import VectorAssembler

ignore = ['object_id', 'target','metaV','bandV','intV','deltaV','rvalV','flux_errV','source_wavelengthV','received_wavelengthV']
assembler = VectorAssembler(
    inputCols=[x for x in train_df.columns if x not in ignore],
    outputCol='features')

train_df=assembler.transform(train_df)

In [160]:
train_df.printSchema()

root
 |-- object_id: integer (nullable = true)
 |-- target: integer (nullable = true)
 |-- metaV: vector (nullable = true)
 |-- bandV: vector (nullable = true)
 |-- intV: vector (nullable = true)
 |-- deltaV: vector (nullable = true)
 |-- rvalV: vector (nullable = true)
 |-- fluxV: vector (nullable = true)
 |-- flux_errV: vector (nullable = true)
 |-- source_wavelengthV: vector (nullable = true)
 |-- received_wavelengthV: vector (nullable = true)
 |-- features: vector (nullable = true)



In [153]:
8*256+10+14

2072

In [161]:
from pyspark.ml.feature import StringIndexer, StandardScaler

In [162]:
scaler = StandardScaler(inputCol="features", outputCol="scaled_features", withStd=True, withMean=True)
fitted_scaler = scaler.fit(train_df)
scaled_df = fitted_scaler.transform(train_df)
scaled_df.show(5)

+---------+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|object_id|target|               metaV|               bandV|                intV|              deltaV|               rvalV|               fluxV|           flux_errV|  source_wavelengthV|received_wavelengthV|            features|     scaled_features|
+---------+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|    23409|     4|[0.0,0.0,0.0,0.0,...|[0.0,2.0,5.0,3.0,...|[5.99590000000171...|[0.99199999999837...|[8.18419999999605...|[59876.0231,60208...|[2.181061,1.44283...|[313.377808988764...|[357.0,621.0,1004...|[59876.0231,60208...|[0.93990762173110...|


In [163]:
nb_classes = train_df.select("target").distinct().count()
input_dim = len(train_df.select("features").first()[0])
print(f"We have {nb_classes} classes and {input_dim} features")


We have 14 classes and 256 features


In [164]:
from keras import optimizers
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation
from keras.utils import np_utils, generic_utils

In [165]:
model = Sequential()
model.add(Dense(512, input_shape=(input_dim,)))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam')

In [166]:
from elephas.ml_model import ElephasEstimator

In [167]:
adam = optimizers.Adam(lr=0.01)
opt_conf = optimizers.serialize(adam)

# Initialize SparkML Estimator and set all relevant properties
estimator = ElephasEstimator()
estimator.setFeaturesCol("scaled_features")             # These two come directly from pyspark,
estimator.setLabelCol("target")                 # hence the camel case. Sorry :)
estimator.set_keras_model_config(model.to_yaml())       # Provide serialized Keras model
estimator.set_categorical_labels(True)
estimator.set_nb_classes(nb_classes)
estimator.set_num_workers(40)  # We just use one worker here. Feel free to adapt it.
estimator.set_epochs(50) 
estimator.set_batch_size(32) # was 128
estimator.set_verbosity(2) # was 1
estimator.set_validation_split(0.15)
estimator.set_optimizer_config(opt_conf)
estimator.set_mode("synchronous") # Was synchronous
estimator.set_loss("categorical_crossentropy")
estimator.set_metrics(['acc'])

ElephasEstimator_42cbac6c17c16139f8fc

In [168]:
from pyspark.ml import Pipeline

In [169]:
pipeline = Pipeline(stages=[scaler, estimator])

In [170]:
import time

start=time.time()
fitted_pipeline = pipeline.fit(train_df) # Fit model to data
elapsedTime=time.time()-start
print(f"Model trained in {elapsedTime} seconds")

>>> Fit model
>>> Synchronous training complete.
Model trained in 72.89576244354248 seconds


In [171]:
prediction = fitted_pipeline.transform(train_df) # Evaluate on train data.
# prediction = fitted_pipeline.transform(test_df) # <-- The same code evaluates test data.
#pnl = prediction.select("target", "prediction")
#pnl.show(100)

In [173]:
pnl = prediction.select("target", "prediction")

In [174]:
pnl.show(10)

Py4JJavaError: An error occurred while calling o2145.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 102.0 failed 4 times, most recent failure: Lost task 0.3 in stage 102.0 (TID 2223, hercules-1-4.nimbus.pawsey.org.au, executor 9): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/worker.py", line 253, in main
    process()
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/worker.py", line 248, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/serializers.py", line 379, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/util.py", line 55, in wrapper
    return f(*args, **kwargs)
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/sql/session.py", line 707, in prepare
    verify_func(obj)
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/sql/types.py", line 1421, in verify
    verify_value(obj)
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/sql/types.py", line 1400, in verify_struct
    "length of fields (%d)" % (len(obj), len(verifiers))))
ValueError: Length of object (15) does not match with length of fields (14)

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:330)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:470)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:453)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:284)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:439)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:836)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:836)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:381)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1651)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1639)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1638)
	at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1638)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
	at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:831)
	at scala.Option.foreach(Option.scala:257)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:831)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:1872)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1821)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1810)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:48)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:642)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2034)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2055)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2074)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:363)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:38)
	at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3278)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2489)
	at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2489)
	at org.apache.spark.sql.Dataset$$anonfun$52.apply(Dataset.scala:3259)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:77)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3258)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2489)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2703)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:254)
	at sun.reflect.GeneratedMethodAccessor82.invoke(Unknown Source)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/worker.py", line 253, in main
    process()
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/worker.py", line 248, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/serializers.py", line 379, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/util.py", line 55, in wrapper
    return f(*args, **kwargs)
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/sql/session.py", line 707, in prepare
    verify_func(obj)
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/sql/types.py", line 1421, in verify
    verify_value(obj)
  File "/opt/cloudera/parcels/SPARK2/lib/spark2/python/pyspark/sql/types.py", line 1400, in verify_struct
    "length of fields (%d)" % (len(obj), len(verifiers))))
ValueError: Length of object (15) does not match with length of fields (14)

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:330)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:470)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:453)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:284)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:439)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
	at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:836)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:836)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:381)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more


# Split testing below. you can delete after we work it out

https://stackoverflow.com/questions/37077432/how-to-estimate-dataframe-real-size-in-pyspark