In [1]:
import pandas as pd
from pyspark.sql import SparkSession, functions as F
import lbl2vec
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
import numpy as np
from pyspark.sql.functions import date_format
import statsmodels.api as sm
from statsmodels.formula.api import ols
from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import OneHotEncoder, VectorAssembler
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt


In [2]:
# Create a spark session
spark = (
    SparkSession.builder.appName("MAST30034 Project 2")
    .config("spark.sql.repl.eagerEval.enabled", True) 
    .config("spark.sql.parquet.cacheMetadata", "true")
    .config("spark.sql.session.timeZone", "Etc/UTC")
    .config("spark.driver.memory", "4g")
    .config("spark.executor.memory", "8g")
    .getOrCreate()
)

22/10/04 02:33:30 WARN Utils: Your hostname, MacBook-Air-3.local resolves to a loopback address: 127.0.0.1; using 192.168.0.66 instead (on interface en0)
22/10/04 02:33:30 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/10/04 02:33:30 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
22/10/04 02:33:31 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
22/10/04 02:33:31 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


In [3]:
# Read in data from ETL.py file
%run '../scripts/ETL.py' '../scripts/paths.json'
final_join3.limit(5)



22/10/04 02:33:32 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


                                                                                

22/10/04 02:34:24 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


                                                                                

merchant_name,merchant_abn,categories,take_rate,revenue_levels,name,address,state,gender,trans_merchant_abn,dollar_value,order_id,order_datetime,user_id,consumer_id,postcodes,int_sa2,SA2_code,SA2_name,income_2018-2019,total_males,total_females,total_persons,state_code,state_name,population_2020,population_2021
Egestas Nunc Asso...,11121775571,digital goods: bo...,6.58,a,Christopher Rodri...,30554 Evans Strea...,NSW,Male,11121775571,11.28829564583802,2bd2a61d-72e5-42d...,2021-08-20,3698,1175,2299,111031231,111031231,Shortland - Jesmond,242936885,6412,6179,12593,1,New South Wales,12598,12694
Morbi Accumsan In...,19618998054,tent and aWning s...,1.52,c,Christopher Rodri...,30554 Evans Strea...,NSW,Male,19618998054,62.90176609196828,3582b1f8-4577-403...,2021-05-16,3698,1175,2299,111031231,111031231,Shortland - Jesmond,242936885,6412,6179,12593,1,New South Wales,12598,12694
Eu Dolor Egestas PC,94472466107,"cable, satellite,...",6.23,a,Christopher Rodri...,30554 Evans Strea...,NSW,Male,94472466107,172.15375126873164,cb05d49f-c2fa-453...,2021-07-22,3698,1175,2299,111031231,111031231,Shortland - Jesmond,242936885,6412,6179,12593,1,New South Wales,12598,12694
Urna Justo Indust...,31472801314,music shops - mus...,6.56,a,Christopher Rodri...,30554 Evans Strea...,NSW,Male,31472801314,0.4894787650356477,aeec15c1-67e8-4cb...,2021-05-18,3698,1175,2299,111031231,111031231,Shortland - Jesmond,242936885,6412,6179,12593,1,New South Wales,12598,12694
Eu Sem Pellentesq...,35424691626,"computers, comput...",3.9,b,Christopher Rodri...,30554 Evans Strea...,NSW,Male,35424691626,7.360217018778133,9df473ba-102d-461...,2021-07-04,3698,1175,2299,111031231,111031231,Shortland - Jesmond,242936885,6412,6179,12593,1,New South Wales,12598,12694


In [4]:
final_join3.count()

                                                                                

10540181

In [5]:
tagged_merchants = pd.read_csv("../data/curated/tagged_merchants.csv")
tagged_merchants = tagged_merchants.iloc[:,1:]
tagged_merchants.drop(['tags', 'name', 'cleaned_tags', 'store_type'], axis=1, inplace=True)
tagged_merchants.to_parquet("../data/curated/tagged_merchants.parquet")
tagged_merchants_sdf = spark.read.parquet("../data/curated/tagged_merchants.parquet")

In [6]:
tagged_merchants_sdf = tagged_merchants_sdf.withColumnRenamed('merchant_abn',

    'tagged_merchant_abn'
)

In [7]:
tagged_merchants_sdf.show(5)

+-------------------+--------------------+
|tagged_merchant_abn|            category|
+-------------------+--------------------+
|        10023283211|           Furniture|
|        10142254217|         Electronics|
|        10165489824|        Toys and DIY|
|        10187291046|        Toys and DIY|
|        10192359162|Books, Stationary...|
+-------------------+--------------------+
only showing top 5 rows



In [8]:
final_join3.createOrReplaceTempView("join")
tagged_merchants_sdf.createOrReplaceTempView("tagged")

joint = spark.sql(""" 

SELECT *
FROM join
INNER JOIN tagged
ON join.merchant_abn = tagged.tagged_merchant_abn
""")

joint = joint.drop('tagged_merchant_abn')

In [9]:
joint.count()

                                                                                

10109254

In [10]:
joint.createOrReplaceTempView("group")

a = spark.sql(""" 

SELECT *, (dollar_value - take_rate) AS total_earning
FROM group
""")

In [11]:
# Extracting the year, month, day from the timestamp

a = a.withColumn("Year", 
date_format('order_datetime', 'yyyy'))

a  = a.withColumn("Month", 
date_format('order_datetime', 'MMMM'))


a = a.withColumn("Day",
date_format(("order_datetime"), "E"))


In [12]:
a = a.drop('merchant_abn', 'categories','name', 'address', 'trans_merchant_abn', 'order_id','order_datetime','user_id','consumer_id','int_sa2','SA2_code',
'SA2_name','income_2018-2019','total_males', 'total_females','total_persons', 'state_code','state_name','population_2020', 'population_2021','total_earning')

In [13]:
 
# Find Count of Null, None, NaN of All DataFrame Columns
from pyspark.sql.functions import col,isnan, when, count
a.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in a.columns]
   ).show()



+-------------+---------+--------------+-----+------+------------+---------+--------+----+-----+---+
|merchant_name|take_rate|revenue_levels|state|gender|dollar_value|postcodes|category|Year|Month|Day|
+-------------+---------+--------------+-----+------+------------+---------+--------+----+-----+---+
|            0|        0|             0|    0|     0|           0|        0|       0|   0|    0|  0|
+-------------+---------+--------------+-----+------+------------+---------+--------+----+-----+---+



                                                                                

In [14]:
a.printSchema()

root
 |-- merchant_name: string (nullable = true)
 |-- take_rate: double (nullable = true)
 |-- revenue_levels: string (nullable = true)
 |-- state: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- dollar_value: double (nullable = true)
 |-- postcodes: string (nullable = true)
 |-- category: string (nullable = true)
 |-- Year: string (nullable = true)
 |-- Month: string (nullable = true)
 |-- Day: string (nullable = true)



In [21]:
a.createOrReplaceTempView("agg")

male = spark.sql(""" 

SELECT CONCAT(merchant_name, Year, Month, Day) AS m_name, COUNT(gender) as males
FROM agg
WHERE gender = 'Male'
GROUP BY merchant_name, Year, Month, Day
""")

male.show(5)

female = spark.sql(""" 

SELECT CONCAT(merchant_name, Year, Month, Day) AS f_name, COUNT(gender) as females
FROM agg
WHERE gender = 'Female'
GROUP BY merchant_name, Year, Month, Day
""")
female.show(5)

                                                                                

+--------------------+-----+
|              m_name|males|
+--------------------+-----+
|Tempus Scelerisqu...|   43|
|Euismod In Corp.2...|  127|
|Felis Purus Found...|   19|
|Mauris Consulting...|   39|
|In Nec Industries...|   22|
+--------------------+-----+
only showing top 5 rows





+--------------------+-------+
|              f_name|females|
+--------------------+-------+
|Ante Dictum LLC20...|     45|
|Fusce Mollis Duis...|      2|
|Luctus LLC2022Jun...|     11|
|Nulla Semper LLC2...|     17|
|Vitae Risus Indus...|     14|
+--------------------+-------+
only showing top 5 rows



                                                                                

In [16]:
a.show(2)



+--------------------+---------+--------------+-----+------+-----------------+---------+--------------------+----+------+---+
|       merchant_name|take_rate|revenue_levels|state|gender|     dollar_value|postcodes|            category|Year| Month|Day|
+--------------------+---------+--------------+-----+------+-----------------+---------+--------------------+----+------+---+
|Egestas Nunc Asso...|     6.58|             a|  NSW|  Male|11.28829564583802|     2299|Books, Stationary...|2021|August|Fri|
|Morbi Accumsan In...|     1.52|             c|  NSW|  Male|62.90176609196828|     2299|Books, Stationary...|2021|   May|Sun|
+--------------------+---------+--------------+-----+------+-----------------+---------+--------------------+----+------+---+
only showing top 2 rows



                                                                                

In [22]:
a.createOrReplaceTempView("agg")

temp = spark.sql(""" 

SELECT merchant_name, COUNT(merchant_name) AS no_of_transactions, Year, Month, Day, SUM(dollar_value - take_rate) AS total_earnings,
    CONCAT(merchant_name, Year, Month, Day) AS join_col
FROM agg
GROUP BY merchant_name, Year, Month, Day
""")

temp.show()




+--------------------+------------------+----+------+---+------------------+--------------------+
|       merchant_name|no_of_transactions|Year| Month|Day|    total_earnings|            join_col|
+--------------------+------------------+----+------+---+------------------+--------------------+
|      Vestibulum LLP|                34|2021|August|Thu| 1594.124325853504|Vestibulum LLP202...|
|      Mattis Limited|                11|2021|August|Sat|  8830.62893275408|Mattis Limited202...|
|Maecenas Iaculis ...|                 4|2021|August|Sun|3036.5175241306893|Maecenas Iaculis ...|
|Ipsum Dolor Sit C...|              1110|2021|  June|Wed| 39711.61412523745|Ipsum Dolor Sit C...|
|Hendrerit Donec C...|                48|2021|  June|Wed| 2485.051283433287|Hendrerit Donec C...|
|Suspendisse Dui C...|              1029|2021|   May|Thu| 27970.46271006672|Suspendisse Dui C...|
|Euismod Urna Inst...|               344|2021|  June|Fri| 3342.046480591664|Euismod Urna Inst...|
|   In Nec Industrie

                                                                                

In [23]:
temp.createOrReplaceTempView("gender_join")
male.createOrReplaceTempView("m")
female.createOrReplaceTempView("f")

temp2 = spark.sql(""" 

SELECT *
FROM gender_join
INNER JOIN m
ON gender_join.join_col = m.m_name
""")

temp2.createOrReplaceTempView("temp2")

temp3 = spark.sql(""" 

SELECT *
FROM temp2
INNER JOIN f
ON temp2.join_col = f.f_name
""")

temp3.limit(5)

                                                                                

merchant_name,no_of_transactions,Year,Month,Day,total_earnings,join_col,m_name,males,f_name,females
A Aliquet Ltd,8,2021,April,Mon,1830.476073709004,A Aliquet Ltd2021...,A Aliquet Ltd2021...,5,A Aliquet Ltd2021...,3
A Aliquet Ltd,7,2021,April,Sat,1962.1571924103384,A Aliquet Ltd2021...,A Aliquet Ltd2021...,5,A Aliquet Ltd2021...,2
A Aliquet Ltd,7,2021,August,Sat,1979.523232289712,A Aliquet Ltd2021...,A Aliquet Ltd2021...,4,A Aliquet Ltd2021...,3
A Aliquet Ltd,11,2021,December,Thu,2115.328500805588,A Aliquet Ltd2021...,A Aliquet Ltd2021...,3,A Aliquet Ltd2021...,7
A Aliquet Ltd,9,2021,December,Tue,1864.8219116639036,A Aliquet Ltd2021...,A Aliquet Ltd2021...,5,A Aliquet Ltd2021...,4


In [24]:
a.createOrReplaceTempView("features")

e = spark.sql(""" 

SELECT merchant_name AS drop_name, FIRST(take_rate) AS take_rate, FIRST(revenue_levels) AS revenue_levels, FIRST(category) AS category
FROM features
GROUP BY merchant_name
""")

e.show(2)



+----------------+---------+--------------+--------------------+
|       drop_name|take_rate|revenue_levels|            category|
+----------------+---------+--------------+--------------------+
|    A Associates|     4.95|             b|Books, Stationary...|
|A Enim Institute|     6.49|             a|        Toys and DIY|
+----------------+---------+--------------+--------------------+
only showing top 2 rows



                                                                                

In [25]:
temp3.createOrReplaceTempView("edit")
e.createOrReplaceTempView("rates")

temp4 = spark.sql(""" 

SELECT *
FROM edit
INNER JOIN rates
ON edit.merchant_name = rates.drop_name
""")

train = temp4.drop('m_name', 'f_name', 'drop_name','join_col')

train.limit(5)

                                                                                

merchant_name,no_of_transactions,Year,Month,Day,total_earnings,males,females,take_rate,revenue_levels,category
A Aliquet Ltd,8,2021,April,Mon,1830.476073709004,5,3,3.87,b,Furniture
A Aliquet Ltd,7,2021,April,Sat,1962.1571924103384,5,2,3.87,b,Furniture
A Aliquet Ltd,7,2021,August,Sat,1979.523232289712,4,3,3.87,b,Furniture
A Aliquet Ltd,11,2021,December,Thu,2115.328500805588,3,7,3.87,b,Furniture
A Aliquet Ltd,9,2021,December,Tue,1864.8219116639036,5,4,3.87,b,Furniture


In [26]:
train.printSchema()

root
 |-- merchant_name: string (nullable = true)
 |-- no_of_transactions: long (nullable = false)
 |-- Year: string (nullable = true)
 |-- Month: string (nullable = true)
 |-- Day: string (nullable = true)
 |-- total_earnings: double (nullable = true)
 |-- males: long (nullable = false)
 |-- females: long (nullable = false)
 |-- take_rate: double (nullable = true)
 |-- revenue_levels: string (nullable = true)
 |-- category: string (nullable = true)



In [27]:
train.count()

                                                                                

267456

In [28]:
# String indexing the categorical columns

indexer = StringIndexer(inputCols = ['merchant_name', 'Year', 'Month', 'Day', 'revenue_levels','category'],
outputCols = ['merchant_name_num', 'Year_num', 'Month_num', 'Day_num', 'revenue_levels_num','category_num'])

indexd_data = indexer.fit(train).transform(train)


# Applying onehot encoding to the categorical data that is string indexed above
encoder = OneHotEncoder(inputCols = ['merchant_name_num', 'Year_num', 'Month_num', 'Day_num', 'revenue_levels_num','category_num'],
outputCols = ['merchant_name_vec', 'Year_vec', 'Month_vec', 'Day_vec', 'revenue_levels_vec','category_vec'])

onehotdata = encoder.fit(indexd_data).transform(indexd_data)


# Assembling the training data as a vector of features 
assembler1 = VectorAssembler(
inputCols=['no_of_transactions','take_rate', 'merchant_name_vec', 'Year_vec', 'Month_vec', 'Day_vec', 'revenue_levels_vec','category_vec'],
outputCol= "features" )

outdata1 = assembler1.transform(onehotdata)

                                                                                

In [29]:
# Renaming the target column as label

outdata1 = outdata1.withColumnRenamed(
    "total_earnings",
    "label"
)

In [30]:
# Assembling the features as a feature vector 

featureIndexer =\
    VectorIndexer(inputCol="features", 
    outputCol="indexedFeatures").fit(outdata1)

outdata1 = featureIndexer.transform(outdata1)

[Stage 846:> (0 + 8) / 14][Stage 847:> (0 + 0) / 14][Stage 848:> (0 + 0) / 14]  

22/10/04 02:45:16 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/10/04 02:45:16 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/10/04 02:45:16 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/10/04 02:45:16 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/10/04 02:45:16 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/10/04 02:45:16 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/10/04 02:45:16 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/10/04 02:45:16 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


[Stage 846:=>(8 + 6) / 14][Stage 847:> (0 + 2) / 14][Stage 848:> (0 + 0) / 14]

22/10/04 02:45:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/10/04 02:45:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/10/04 02:45:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/10/04 02:45:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
22/10/04 02:45:21 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.


                                                                                

In [31]:
# Split the data into training and validation sets (30% held out for testing)

trainingData, testData = outdata1.randomSplit([0.7, 0.3], seed = 20)

In [32]:
trainingData.count(), testData.count()

                                                                                

22/10/04 02:46:16 WARN DAGScheduler: Broadcasting large task binary with size 1563.0 KiB




22/10/04 02:46:42 WARN DAGScheduler: Broadcasting large task binary with size 1563.0 KiB


                                                                                

(187155, 80301)

In [33]:
# Train a RandomForest model.
rf = RandomForestRegressor(featuresCol="indexedFeatures")


# Train model.  
model = rf.fit(trainingData)

# Make predictions.
predictions_validation = model.transform(testData)



22/10/04 02:47:09 WARN DAGScheduler: Broadcasting large task binary with size 1569.6 KiB


                                                                                

22/10/04 02:47:10 WARN DAGScheduler: Broadcasting large task binary with size 1569.7 KiB


                                                                                

22/10/04 02:47:11 WARN DAGScheduler: Broadcasting large task binary with size 1573.6 KiB


                                                                                

22/10/04 02:47:13 WARN DAGScheduler: Broadcasting large task binary with size 1766.5 KiB


[Stage 1086:>                                                       (0 + 8) / 8]

22/10/04 02:47:15 WARN MemoryStore: Not enough space to cache rdd_2395_0 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:15 WARN MemoryStore: Not enough space to cache rdd_2395_7 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:15 WARN MemoryStore: Not enough space to cache rdd_2395_2 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:15 WARN MemoryStore: Not enough space to cache rdd_2395_3 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:15 WARN MemoryStore: Not enough space to cache rdd_2395_1 in memory! (computed 142.4 MiB so far)
22/10/04 02:47:15 WARN MemoryStore: Not enough space to cache rdd_2395_6 in memory! (computed 142.4 MiB so far)
22/10/04 02:47:15 WARN MemoryStore: Not enough space to cache rdd_2395_4 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:15 WARN BlockManager: Persisting block rdd_2395_2 to disk instead.
22/10/04 02:47:15 WARN BlockManager: Persisting block rdd_2395_0 to disk instead.
22/10/04 02:47:15 WARN BlockManager: Persisting bloc

                                                                                

22/10/04 02:47:21 WARN DAGScheduler: Broadcasting large task binary with size 1867.9 KiB


[Stage 1098:>                                                       (0 + 8) / 8]

22/10/04 02:47:21 WARN MemoryStore: Not enough space to cache rdd_2395_1 in memory! (computed 142.4 MiB so far)
22/10/04 02:47:22 WARN MemoryStore: Not enough space to cache rdd_2395_3 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:22 WARN MemoryStore: Not enough space to cache rdd_2395_7 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:22 WARN MemoryStore: Not enough space to cache rdd_2395_0 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:22 WARN MemoryStore: Not enough space to cache rdd_2395_6 in memory! (computed 213.8 MiB so far)


                                                                                

22/10/04 02:47:25 WARN DAGScheduler: Broadcasting large task binary with size 2.0 MiB


[Stage 1110:>                                                       (0 + 8) / 8]

22/10/04 02:47:26 WARN MemoryStore: Not enough space to cache rdd_2395_7 in memory! (computed 142.4 MiB so far)
22/10/04 02:47:26 WARN MemoryStore: Not enough space to cache rdd_2395_3 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:26 WARN MemoryStore: Not enough space to cache rdd_2395_0 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:26 WARN MemoryStore: Not enough space to cache rdd_2395_1 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:26 WARN MemoryStore: Not enough space to cache rdd_2395_6 in memory! (computed 213.8 MiB so far)


                                                                                

22/10/04 02:47:30 WARN DAGScheduler: Broadcasting large task binary with size 2.4 MiB


[Stage 1122:>                                                       (0 + 8) / 8]

22/10/04 02:47:30 WARN MemoryStore: Not enough space to cache rdd_2395_0 in memory! (computed 142.4 MiB so far)
22/10/04 02:47:31 WARN MemoryStore: Not enough space to cache rdd_2395_7 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:31 WARN MemoryStore: Not enough space to cache rdd_2395_6 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:31 WARN MemoryStore: Not enough space to cache rdd_2395_1 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:31 WARN MemoryStore: Not enough space to cache rdd_2395_3 in memory! (computed 213.8 MiB so far)


                                                                                

22/10/04 02:47:34 WARN DAGScheduler: Broadcasting large task binary with size 3.0 MiB
22/10/04 02:47:35 WARN MemoryStore: Not enough space to cache rdd_2395_1 in memory! (computed 142.4 MiB so far)
22/10/04 02:47:35 WARN MemoryStore: Not enough space to cache rdd_2395_6 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:35 WARN MemoryStore: Not enough space to cache rdd_2395_7 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:35 WARN MemoryStore: Not enough space to cache rdd_2395_0 in memory! (computed 213.8 MiB so far)
22/10/04 02:47:35 WARN MemoryStore: Not enough space to cache rdd_2395_3 in memory! (computed 213.8 MiB so far)


                                                                                

In [34]:
# Evaluate the validation set 

predictions_validation.select("prediction", "label", "features").show(5)

# Select (prediction, true label) and compute test error

evaluator_train_rmse = RegressionEvaluator(
    labelCol="label", predictionCol="prediction", metricName="rmse")
rmse_train = evaluator_train_rmse.evaluate(predictions_validation)
print("Root Mean Squared Error (RMSE) on train data = %g" % rmse_train)

evaluator_train_mae = RegressionEvaluator(
    labelCol="label", predictionCol="prediction", metricName="mae")
mae_train = evaluator_train_mae.evaluate(predictions_validation)
print("Root Mean Squared Error (MAE) on raint data = %g" % mae_train)



22/10/04 02:48:05 WARN DAGScheduler: Broadcasting large task binary with size 1577.4 KiB


                                                                                

+-----------------+------------------+--------------------+
|       prediction|             label|            features|
+-----------------+------------------+--------------------+
| 3462.74890804621|1450.4324499682334|(3604,[0,1,1376,3...|
|3554.280947682284| 1830.476073709004|(3604,[0,1,1376,3...|
|3554.280947682284|1887.5757915518245|(3604,[0,1,1376,3...|
|3554.280947682284| 2115.328500805588|(3604,[0,1,1376,3...|
|3554.280947682284| 4478.059685948354|(3604,[0,1,664,35...|
+-----------------+------------------+--------------------+
only showing top 5 rows





22/10/04 02:48:30 WARN DAGScheduler: Broadcasting large task binary with size 1574.8 KiB


[Stage 1237:>                                                       (0 + 8) / 8]

22/10/04 02:48:32 WARN DAGScheduler: Broadcasting large task binary with size 1576.0 KiB
Root Mean Squared Error (RMSE) on train data = 6800.62




22/10/04 02:48:57 WARN DAGScheduler: Broadcasting large task binary with size 1574.8 KiB


[Stage 1289:>                                                       (0 + 8) / 8]

22/10/04 02:48:58 WARN DAGScheduler: Broadcasting large task binary with size 1575.9 KiB
Root Mean Squared Error (MAE) on raint data = 3759.11


                                                                                