# Step 1: Create the SparkSession Object

In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.appName('random_forest').getOrCreate()

# Step 2: Read the Dataset

In [4]:
df = spark.read.csv('affairs.csv', inferSchema=True, header=True)

# Step 3: Exploratory Data Analysis

In [5]:
def shape(df):
    return (df.count(), len(df.columns))

In [6]:
df.columns

['rate_marriage', 'age', 'yrs_married', 'children', 'religious', 'affairs']

In [7]:
df.show(5)

+-------------+----+-----------+--------+---------+-------+
|rate_marriage| age|yrs_married|children|religious|affairs|
+-------------+----+-----------+--------+---------+-------+
|            5|32.0|        6.0|     1.0|        3|      0|
|            4|22.0|        2.5|     0.0|        2|      0|
|            3|32.0|        9.0|     3.0|        3|      1|
|            3|27.0|       13.0|     3.0|        1|      1|
|            4|22.0|        2.5|     0.0|        1|      1|
+-------------+----+-----------+--------+---------+-------+
only showing top 5 rows



In [12]:
df.select(df.columns[:-1]).describe().show()

+-------+------------------+------------------+-----------------+------------------+------------------+
|summary|     rate_marriage|               age|      yrs_married|          children|         religious|
+-------+------------------+------------------+-----------------+------------------+------------------+
|  count|              6366|              6366|             6366|              6366|              6366|
|   mean| 4.109644989004084|29.082862079798932| 9.00942507068803|1.3968740182218033|2.4261702796104303|
| stddev|0.9614295945655025| 6.847881883668817|7.280119972766412| 1.433470828560344|0.8783688402641785|
|    min|                 1|              17.5|              0.5|               0.0|                 1|
|    max|                 5|              42.0|             23.0|               5.5|                 4|
+-------+------------------+------------------+-----------------+------------------+------------------+



In [9]:
df.printSchema()

root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)



In [13]:
df.groupBy('affairs').count().show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1| 2053|
|      0| 4313|
+-------+-----+



In [17]:
df.groupBy('rate_marriage', 'affairs').count().orderBy('rate_marriage', ascending=True).show()

+-------------+-------+-----+
|rate_marriage|affairs|count|
+-------------+-------+-----+
|            1|      0|   25|
|            1|      1|   74|
|            2|      1|  221|
|            2|      0|  127|
|            3|      1|  547|
|            3|      0|  446|
|            4|      0| 1518|
|            4|      1|  724|
|            5|      0| 2197|
|            5|      1|  487|
+-------------+-------+-----+



In [19]:
df.groupBy('religious', 'affairs').count().orderBy('religious', ascending=True).show()

+---------+-------+-----+
|religious|affairs|count|
+---------+-------+-----+
|        1|      1|  408|
|        1|      0|  613|
|        2|      0| 1448|
|        2|      1|  819|
|        3|      0| 1715|
|        3|      1|  707|
|        4|      0|  537|
|        4|      1|  119|
+---------+-------+-----+



In [20]:
df.groupBy('children', 'affairs').count().orderBy('children', ascending=True).show()

+--------+-------+-----+
|children|affairs|count|
+--------+-------+-----+
|     0.0|      0| 1912|
|     0.0|      1|  502|
|     1.0|      1|  412|
|     1.0|      0|  747|
|     2.0|      1|  608|
|     2.0|      0|  873|
|     3.0|      1|  321|
|     3.0|      0|  460|
|     4.0|      1|  131|
|     4.0|      0|  197|
|     5.5|      1|   79|
|     5.5|      0|  124|
+--------+-------+-----+



In [21]:
df.groupBy('affairs').mean().show()

+-------+------------------+------------------+------------------+------------------+------------------+------------+
|affairs|avg(rate_marriage)|          avg(age)|  avg(yrs_married)|     avg(children)|    avg(religious)|avg(affairs)|
+-------+------------------+------------------+------------------+------------------+------------------+------------+
|      1|3.6473453482708234|30.537018996590355|11.152459814905017|1.7289332683877252| 2.261568436434486|         1.0|
|      0| 4.329700904242986| 28.39067934152562| 7.989334569904939|1.2388128912589844|2.5045212149316023|         0.0|
+-------+------------------+------------------+------------------+------------------+------------------+------------+



# Step 4: Feature Engineering

In [23]:
from pyspark.ml.feature import VectorAssembler

In [24]:
df_assembler = VectorAssembler(inputCols=['rate_marriage', 'age', 'yrs_married', 'children', 'religious'],
                               outputCol='features')

In [25]:
df = df_assembler.transform(df)

In [26]:
df.printSchema()

root
 |-- rate_marriage: integer (nullable = true)
 |-- age: double (nullable = true)
 |-- yrs_married: double (nullable = true)
 |-- children: double (nullable = true)
 |-- religious: integer (nullable = true)
 |-- affairs: integer (nullable = true)
 |-- features: vector (nullable = true)



In [29]:
model_df = df.select('features', 'affairs')
model_df.show(3)

+--------------------+-------+
|            features|affairs|
+--------------------+-------+
|[5.0,32.0,6.0,1.0...|      0|
|[4.0,22.0,2.5,0.0...|      0|
|[3.0,32.0,9.0,3.0...|      1|
+--------------------+-------+
only showing top 3 rows



In [30]:
train_df, test_df = model_df.randomSplit([.75, .25])
shape(train_df), shape(test_df)

((4755, 2), (1611, 2))

In [31]:
train_df.groupBy('affairs').count().show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1| 1554|
|      0| 3201|
+-------+-----+



In [32]:
test_df.groupBy('affairs').count().show()

+-------+-----+
|affairs|count|
+-------+-----+
|      1|  499|
|      0| 1112|
+-------+-----+



# Step 6: Build and Train Random Forest Model

In [33]:
from pyspark.ml.classification import RandomForestClassifier

In [36]:
rf_classifier = RandomForestClassifier(labelCol='affairs', numTrees=50).fit(train_df)

# Step 7: Evaluation on Test Data

In [35]:
rf_predictions = rf_classifier.transform(test_df)
rf_predictions.show()

+--------------------+-------+--------------------+--------------------+----------+
|            features|affairs|       rawPrediction|         probability|prediction|
+--------------------+-------+--------------------+--------------------+----------+
|[1.0,22.0,2.5,0.0...|      1|[24.9078583298722...|[0.49815716659744...|       1.0|
|[1.0,22.0,2.5,1.0...|      1|[22.2805630024658...|[0.44561126004931...|       1.0|
|[1.0,27.0,2.5,0.0...|      1|[25.0809376508686...|[0.50161875301737...|       0.0|
|[1.0,27.0,6.0,1.0...|      0|[16.8320282278415...|[0.33664056455683...|       1.0|
|[1.0,27.0,6.0,1.0...|      1|[17.9560830656535...|[0.35912166131307...|       1.0|
|[1.0,27.0,6.0,3.0...|      0|[14.4777010584459...|[0.28955402116891...|       1.0|
|[1.0,27.0,13.0,2....|      1|[13.3294831600763...|[0.26658966320152...|       1.0|
|[1.0,32.0,2.5,1.0...|      0|[25.2954117727331...|[0.50590823545466...|       0.0|
|[1.0,32.0,16.5,2....|      1|[14.0228129704326...|[0.28045625940865...|    

In [37]:
rf_predictions.groupBy('prediction').count().show()

+----------+-----+
|prediction|count|
+----------+-----+
|       0.0| 1372|
|       1.0|  239|
+----------+-----+



In [41]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.evaluation import BinaryClassificationEvaluator

## Accuracy

In [45]:
rf_accuracy = MulticlassClassificationEvaluator(labelCol='affairs', metricName='accuracy').evaluate(rf_predictions)
rf_accuracy

0.7219118559900682

## Precision

In [47]:
rf_precision = MulticlassClassificationEvaluator(labelCol='affairs', metricName='weightedPrecision').evaluate(rf_predictions)
rf_precision

0.7000776746610892

## AUC

In [49]:
rf_auc = BinaryClassificationEvaluator(labelCol='affairs').evaluate(rf_predictions)
rf_auc

0.7438140669828865

In [50]:
rf_classifier.featureImportances

SparseVector(5, {0: 0.6162, 1: 0.0257, 2: 0.2222, 3: 0.0754, 4: 0.0606})

In [51]:
df.schema['features'].metadata['ml_attr']['attrs']

{'numeric': [{'idx': 0, 'name': 'rate_marriage'},
  {'idx': 1, 'name': 'age'},
  {'idx': 2, 'name': 'yrs_married'},
  {'idx': 3, 'name': 'children'},
  {'idx': 4, 'name': 'religious'}]}

# Step 8: Saving the Model

In [52]:
from pyspark.ml.classification import RandomForestClassificationModel

In [53]:
rf_classifier.save('./rf_model')

In [54]:
!ls -a

[36m.[m[m                  .DS_Store          affairs.csv        [36mrf_model[m[m
[36m..[m[m                 [36m.ipynb_checkpoints[m[m main.ipynb


In [56]:
rf = RandomForestClassificationModel.load('./rf_model')
new_predictions = rf.transform(df)

In [57]:
new_predictions

DataFrame[rate_marriage: int, age: double, yrs_married: double, children: double, religious: int, affairs: int, features: vector, rawPrediction: vector, probability: vector, prediction: double]