In [0]:
#import SparkSession
from pyspark.sql import SparkSession
# Create the Spark session for this notebook, or reuse one that is already running.
spark = (
    SparkSession.builder
    .appName('random_forest')
    .getOrCreate()
)

In [0]:
# Load the affairs CSV from DBFS: the first row is the header and column
# types are inferred from the data.
df = spark.read.csv(
    'dbfs:/FileStore/tables/affairs.csv',
    header=True,
    inferSchema=True,
)

In [0]:
# Report the dataset shape as a (row count, column count) tuple.
row_count = df.count()
column_count = len(df.columns)
print((row_count, column_count))

(6366, 6)


In [0]:
# Print each column's name, inferred type and nullability.
df.printSchema()

root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)



In [0]:
# Preview the first five rows of the dataset.
df.show(n=5)

+-------------+----+-----------+--------+---------+-------+
|rate_marriage| age|yrs_married|children|religious|affairs|
+-------------+----+-----------+--------+---------+-------+
|            5|32.0|        6.0|     1.0|        3|      0|
|            4|22.0|        2.5|     0.0|        2|      0|
|            3|32.0|        9.0|     3.0|        3|      1|
|            3|27.0|       13.0|     3.0|        1|      1|
|            4|22.0|        2.5|     0.0|        1|      1|
+-------------+----+-----------+--------+---------+-------+
only showing top 5 rows



In [0]:
# Exploratory data analysis: summary statistics for the five predictor columns
# (the 'affairs' label is left out of the displayed table).
described_cols = ['summary', 'rate_marriage', 'age', 'yrs_married', 'children', 'religious']
df.describe().select(described_cols).show()

+-------+------------------+------------------+-----------------+------------------+------------------+
|summary|     rate_marriage|               age|      yrs_married|          children|         religious|
+-------+------------------+------------------+-----------------+------------------+------------------+
|  count|              6366|              6366|             6366|              6366|              6366|
|   mean| 4.109644989004084|29.082862079798932| 9.00942507068803|1.3968740182218033|2.4261702796104303|
| stddev|0.9614295945655025| 6.847881883668817|7.280119972766412| 1.433470828560344|0.8783688402641785|
|    min|                 1|              17.5|              0.5|               0.0|                 1|
|    max|                 5|              42.0|             23.0|               5.5|                 4|
+-------+------------------+------------------+-----------------+------------------+------------------+



In [0]:
# Class balance of the target: number of rows per 'affairs' value.
affairs_counts = df.groupBy('affairs').count()
affairs_counts.show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1| 2053|
|      0| 4313|
+-------+-----+



In [0]:
# Number of rows for each marriage rating.
rating_counts = df.groupBy('rate_marriage').count()
rating_counts.show()

+-------------+-----+
|rate_marriage|count|
+-------------+-----+
|            1|   99|
|            3|  993|
|            5| 2684|
|            4| 2242|
|            2|  348|
+-------------+-----+



In [0]:
# Cross-tabulate marriage rating against the target, sorted ascending so the
# two 'affairs' classes of each rating sit on adjacent rows.
rating_by_affairs = df.groupBy('rate_marriage', 'affairs').count()
rating_by_affairs.orderBy('rate_marriage', 'affairs', 'count', ascending=True).show()

+-------------+-------+-----+
|rate_marriage|affairs|count|
+-------------+-------+-----+
|            1|      0|   25|
|            1|      1|   74|
|            2|      0|  127|
|            2|      1|  221|
|            3|      0|  446|
|            3|      1|  547|
|            4|      0| 1518|
|            4|      1|  724|
|            5|      0| 2197|
|            5|      1|  487|
+-------------+-------+-----+



In [0]:
# Cross-tabulate the 'religious' score against the target, sorted ascending so
# the two 'affairs' classes of each score sit on adjacent rows.
religious_by_affairs = df.groupBy('religious', 'affairs').count()
religious_by_affairs.orderBy('religious', 'affairs', 'count', ascending=True).show()

+---------+-------+-----+
|religious|affairs|count|
+---------+-------+-----+
|        1|      0|  613|
|        1|      1|  408|
|        2|      0| 1448|
|        2|      1|  819|
|        3|      0| 1715|
|        3|      1|  707|
|        4|      0|  537|
|        4|      1|  119|
+---------+-------+-----+



In [0]:
# Cross-tabulate the 'children' value against the target, sorted ascending so
# the two 'affairs' classes of each value sit on adjacent rows.
children_by_affairs = df.groupBy('children', 'affairs').count()
children_by_affairs.orderBy('children', 'affairs', 'count', ascending=True).show()

+--------+-------+-----+
|children|affairs|count|
+--------+-------+-----+
|     0.0|      0| 1912|
|     0.0|      1|  502|
|     1.0|      0|  747|
|     1.0|      1|  412|
|     2.0|      0|  873|
|     2.0|      1|  608|
|     3.0|      0|  460|
|     3.0|      1|  321|
|     4.0|      0|  197|
|     4.0|      1|  131|
|     5.5|      0|  124|
|     5.5|      1|   79|
+--------+-------+-----+



In [0]:
# Mean of every numeric column within each 'affairs' class.
means_by_affairs = df.groupBy('affairs').mean()
means_by_affairs.show()

+-------+------------------+------------------+------------------+------------------+------------------+------------+
|affairs|avg(rate_marriage)|          avg(age)|  avg(yrs_married)|     avg(children)|    avg(religious)|avg(affairs)|
+-------+------------------+------------------+------------------+------------------+------------------+------------+
|      1|3.6473453482708234|30.537018996590355|11.152459814905017|1.7289332683877252| 2.261568436434486|         1.0|
|      0| 4.329700904242986| 28.39067934152562| 7.989334569904939|1.2388128912589844|2.5045212149316023|         0.0|
+-------+------------------+------------------+------------------+------------------+------------------+------------+



In [0]:
from pyspark.ml.feature import VectorAssembler

In [0]:
# Pack the five predictor columns into a single vector column named "features".
feature_cols = ['rate_marriage', 'age', 'yrs_married', 'children', 'religious']
df_assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
# `df` is reassigned in place, so on a re-run of this cell it already carries a
# "features" column and transform() would fail because the output column exists.
# Dropping it first makes the cell idempotent; drop() is a no-op when the column
# is absent, so the first run behaves exactly as before.
df = df_assembler.transform(df.drop("features"))

In [0]:
# Confirm the schema now includes the assembled 'features' vector column.
df.printSchema()

root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)
 |-- features: vector (nullable = true)



In [0]:
# Show the assembled feature vector next to the label for ten rows, without
# truncating the cell contents.
df.select('features', 'affairs').show(n=10, truncate=False)

+-----------------------+-------+
|features               |affairs|
+-----------------------+-------+
|[5.0,32.0,6.0,1.0,3.0] |0      |
|[4.0,22.0,2.5,0.0,2.0] |0      |
|[3.0,32.0,9.0,3.0,3.0] |1      |
|[3.0,27.0,13.0,3.0,1.0]|1      |
|[4.0,22.0,2.5,0.0,1.0] |1      |
|[4.0,37.0,16.5,4.0,3.0]|1      |
|[5.0,27.0,9.0,1.0,1.0] |1      |
|[4.0,27.0,9.0,0.0,2.0] |1      |
|[5.0,37.0,23.0,5.5,2.0]|1      |
|[5.0,37.0,23.0,5.5,2.0]|1      |
+-----------------------+-------+
only showing top 10 rows



In [0]:
# Keep only the two columns used for modelling: the feature vector and the label.
model_df = df.select('features', 'affairs')

In [0]:
# Hold out roughly 25% of the rows for testing. Without an explicit seed
# randomSplit picks a new random one on each call, so the split (and every
# metric computed from it) would change between runs; pin it for reproducibility.
train_df, test_df = model_df.randomSplit([0.75, 0.25], seed=42)

In [0]:
# Number of rows that landed in the training split.
train_df.count()

Out[19]: 4750

In [0]:
# Class balance of the 'affairs' label within the training split.
train_df.groupBy('affairs').count().show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1| 1516|
|      0| 3234|
+-------+-----+



In [0]:
# Class balance of the 'affairs' label within the test split.
test_df.groupBy('affairs').count().show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1|  537|
|      0| 1079|
+-------+-----+



In [0]:
from pyspark.ml.classification import RandomForestClassifier

In [0]:
# Fit a 50-tree random forest on the training split, predicting 'affairs' from
# the default 'features' column. The seed is pinned so that the randomness used
# while growing the trees, and therefore the fitted model and its metrics, is
# the same on every run of the notebook.
rf_classifier = RandomForestClassifier(labelCol='affairs', numTrees=50, seed=42).fit(train_df)

In [0]:
# Score the held-out test split with the fitted forest.
rf_predictions=rf_classifier.transform(test_df)

In [0]:
# Inspect the scored rows: the model output adds rawPrediction, probability and
# prediction columns next to the features and label.
rf_predictions.show()

+--------------------+-------+--------------------+--------------------+----------+
|            features|affairs|       rawPrediction|         probability|prediction|
+--------------------+-------+--------------------+--------------------+----------+
|[1.0,22.0,2.5,0.0...|      1|[22.1570633833066...|[0.44314126766613...|       1.0|
|[1.0,27.0,2.5,0.0...|      0|[17.9558486444922...|[0.35911697288984...|       1.0|
|[1.0,27.0,2.5,0.0...|      1|[21.8730744056017...|[0.43746148811203...|       1.0|
|[1.0,27.0,6.0,1.0...|      0|[18.6671064292999...|[0.37334212858599...|       1.0|
|[1.0,27.0,6.0,2.0...|      1|[18.6604246421417...|[0.37320849284283...|       1.0|
|[1.0,27.0,9.0,1.0...|      1|[18.3610605852142...|[0.36722121170428...|       1.0|
|[1.0,32.0,2.5,1.0...|      0|[22.6733214919040...|[0.45346642983808...|       1.0|
|[1.0,32.0,13.0,0....|      1|[18.9976720750195...|[0.37995344150039...|       1.0|
|[1.0,32.0,13.0,1....|      1|[18.0822924429207...|[0.36164584885841...|    

In [0]:
# Number of test rows assigned to each predicted class.
rf_predictions.groupBy('prediction').count().show()

+----------+-----+
|prediction|count|
+----------+-----+
|       0.0| 1318|
|       1.0|  298|
+----------+-----+



In [0]:
# Compare class probabilities, true label and predicted label for ten test
# rows, without truncating the probability vectors.
rf_predictions.select('probability', 'affairs', 'prediction').show(n=10, truncate=False)

+----------------------------------------+-------+----------+
|probability                             |affairs|prediction|
+----------------------------------------+-------+----------+
|[0.4431412676661325,0.5568587323338676] |1      |1.0       |
|[0.35911697288984423,0.6408830271101558]|0      |1.0       |
|[0.4374614881120344,0.5625385118879656] |1      |1.0       |
|[0.37334212858599836,0.6266578714140018]|0      |1.0       |
|[0.37320849284283564,0.6267915071571645]|1      |1.0       |
|[0.36722121170428573,0.6327787882957143]|1      |1.0       |
|[0.4534664298380801,0.54653357016192]   |0      |1.0       |
|[0.3799534415003918,0.6200465584996081] |1      |1.0       |
|[0.3616458488584147,0.6383541511415853] |1      |1.0       |
|[0.36237653455992086,0.6376234654400791]|1      |1.0       |
+----------------------------------------+-------+----------+
only showing top 10 rows



In [0]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [0]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [0]:
# Accuracy of the forest's predictions on the test split, measured against
# the 'affairs' label.
accuracy_evaluator = MulticlassClassificationEvaluator(
    labelCol='affairs',
    metricName='accuracy',
)
rf_accuracy = accuracy_evaluator.evaluate(rf_predictions)

In [0]:
# Report accuracy as a whole-number percentage.
accuracy_message = 'The accuracy of RF on test data is {0:.0%}'.format(rf_accuracy)
print(accuracy_message)

The accuracy of RF on test data is 71%


In [0]:
# Unrounded accuracy value.
print(rf_accuracy)

0.7147277227722773


In [0]:
# Weighted precision of the forest's predictions on the test split, measured
# against the 'affairs' label.
precision_evaluator = MulticlassClassificationEvaluator(
    labelCol='affairs',
    metricName='weightedPrecision',
)
rf_precision = precision_evaluator.evaluate(rf_predictions)

In [0]:
# Report the weighted precision as a whole-number percentage.
precision_message = 'The precision rate on test data is {0:.0%}'.format(rf_precision)
print(precision_message)

The precision rate on test data is 70%


In [0]:
# Unrounded weighted-precision value.
rf_precision

Out[35]: 0.6989132950974213

In [0]:
# Binary-classification score on the test split using the evaluator's default
# metric (area under the ROC curve, per the PySpark documentation).
auc_evaluator = BinaryClassificationEvaluator(labelCol='affairs')
rf_auc = auc_evaluator.evaluate(rf_predictions)

In [0]:
# Unrounded value returned by the binary evaluator.
print(rf_auc)

0.7406980737733919


In [0]:
# Feature importance

In [0]:
# Importance score of each feature in the fitted forest, keyed by its position
# in the feature vector.
rf_classifier.featureImportances

Out[39]: SparseVector(5, {0: 0.6397, 1: 0.0234, 2: 0.2023, 3: 0.0661, 4: 0.0685})

In [0]:
# Map feature-vector positions back to column names using the attribute
# metadata stored on the 'features' column.
df.schema["features"].metadata["ml_attr"]["attrs"]

Out[40]: {'numeric': [{'idx': 0, 'name': 'rate_marriage'},
  {'idx': 1, 'name': 'age'},
  {'idx': 2, 'name': 'yrs_married'},
  {'idx': 3, 'name': 'children'},
  {'idx': 4, 'name': 'religious'}]}

In [0]:
# Save the model 

In [0]:
# Show the driver's current working directory. A bare `pwd` only works through
# IPython automagic, which applies to single-line cells and stops working if a
# variable named `pwd` is ever defined; os.getcwd() returns the same path as
# plain Python.
import os
os.getcwd()

Out[42]: '/databricks/driver'

In [0]:
# Persist the fitted forest. A plain save() raises when the target path already
# exists, which makes the notebook fail on its second run, so go through the
# writer and overwrite any copy left by a previous run.
rf_classifier.write().overwrite().save("/databricks/driver/RF_model")

In [0]:
from pyspark.ml.classification import RandomForestClassificationModel

In [0]:
# Reload the persisted forest from the path it was saved to.
rf=RandomForestClassificationModel.load("/databricks/driver/RF_model")

In [0]:
# Score the test split with the reloaded model.
model_preditions=rf.transform(test_df)

In [0]:
# Display the reloaded model's predictions.
model_preditions.show()

+--------------------+-------+--------------------+--------------------+----------+
|            features|affairs|       rawPrediction|         probability|prediction|
+--------------------+-------+--------------------+--------------------+----------+
|[1.0,22.0,2.5,0.0...|      1|[22.1570633833066...|[0.44314126766613...|       1.0|
|[1.0,27.0,2.5,0.0...|      0|[17.9558486444922...|[0.35911697288984...|       1.0|
|[1.0,27.0,2.5,0.0...|      1|[21.8730744056017...|[0.43746148811203...|       1.0|
|[1.0,27.0,6.0,1.0...|      0|[18.6671064292999...|[0.37334212858599...|       1.0|
|[1.0,27.0,6.0,2.0...|      1|[18.6604246421417...|[0.37320849284283...|       1.0|
|[1.0,27.0,9.0,1.0...|      1|[18.3610605852142...|[0.36722121170428...|       1.0|
|[1.0,32.0,2.5,1.0...|      0|[22.6733214919040...|[0.45346642983808...|       1.0|
|[1.0,32.0,13.0,0....|      1|[18.9976720750195...|[0.37995344150039...|       1.0|
|[1.0,32.0,13.0,1....|      1|[18.0822924429207...|[0.36164584885841...|    