/
random_forest.py
57 lines (42 loc) · 1.91 KB
/
random_forest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import numpy as np
# Creating spark session
spark = SparkSession.builder.appName("DecisionTree App").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
# Loading the data
data = spark.read.format("csv").option("header", True) \
.option("inferSchema", True) \
.option("delimiter", ",") \
.load("D:\\UMKC\\__Spring2020\\CS5590BDP\Module-2\\Lesson-7\\MachineLearning\\data\\adult.data")
data.printSchema()
data = data.withColumn("X", F.when(F.col("X") == ' <=50K', 0).when(F.col("X") == ' >50K', 1))
data = data.withColumnRenamed("X", "label")
data = data.select(data.label.cast("double"),"age", "education-num", "hours-per-week")
data.show()
assembler = VectorAssembler(inputCols=data.columns[1:], outputCol="features")
data = assembler.transform(data)
data.show()
# Splitting the data into training and data set
training, test = data.select("label","features").randomSplit([0.70, 0.30])
# Create Random Forest model and fit the model with training dataset
rf = RandomForestClassifier()
model = rf.fit(training)
# Generate prediction from test dataset
pred = model.transform(test)
# Evaluate the accuracy of the model
evaluator = MulticlassClassificationEvaluator()
accuracy = evaluator.evaluate(pred)
# Show model accuracy
print("Accuracy:", accuracy)
# Report
predictionAndLabels = pred.select("prediction", "label").rdd
metrics = MulticlassMetrics(predictionAndLabels)
print("Confusion Matrix:", metrics.confusionMatrix())
print("Precision:", metrics.precision())
print("Recall:", metrics.recall())
print("F-measure:", metrics.fMeasure())