In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pyspark.ml.regression import GeneralizedLinearRegression
import pyspark.sql.functions as func
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator


# Load the complaints table from the metastore.
# NOTE(review): an earlier version read /FileStore/tables/complaintofficer.csv
# directly; presumably the "complaints" table was created from that file — confirm.
data = spark.table("complaints")

# complaints2017 is used as a multiclass label (the saved output shows a
# 3-class probability vector).
LABEL_COL = 'complaints2017'
# Fixed seed so the train/test split (and therefore the model and every output
# below) is reproducible under Restart & Run All.
SEED = 42

cols = ['gender1', 'race1', 'age', 'complaintsbefore2017', LABEL_COL]
# Feature columns are everything selected except the label, so the two lists
# cannot drift apart.
col = [c for c in cols if c != LABEL_COL]
df = data[cols]

# NOTE(review): na.fill(0) also replaces missing *labels* with 0, silently
# counting unknown outcomes as "no complaints" — confirm this is intended.
assembler = VectorAssembler(inputCols=col, outputCol="features")
df = assembler.transform(df.na.fill(0))

dt = DecisionTreeClassifier(labelCol=LABEL_COL, featuresCol='features')

train, test = df.randomSplit([0.7, 0.3], seed=SEED)

model = dt.fit(train)

pred = model.transform(test)

evaluator = MulticlassClassificationEvaluator(
    labelCol=LABEL_COL, predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(pred)
print("Test Error = %g " % (1.0 - accuracy))
# The label is heavily imbalanced (one leaf holds 773 vs 33 rows), so accuracy
# alone can look good while almost always predicting class 0; report weighted
# F1 alongside it.
f1 = evaluator.evaluate(pred, {evaluator.metricName: "f1"})
print("Test weighted F1 = %g" % f1)

print(model)

In [2]:
# Render the scored test split: original inputs, assembled feature vector,
# raw per-class leaf counts, class probabilities and the predicted class.
display(pred.select(*pred.columns))


gender1,race1,age,complaintsbefore2017,complaints2017,features,rawPrediction,probability,prediction
1,1,0,0,0,"List(1, 4, List(), List(1.0, 1.0, 0.0, 0.0))","List(1, 3, List(), List(773.0, 33.0, 0.0))","List(1, 3, List(), List(0.9590570719602978, 0.04094292803970223, 0.0))",0.0
1,1,22,0,0,"List(1, 4, List(), List(1.0, 1.0, 22.0, 0.0))","List(1, 3, List(), List(773.0, 33.0, 0.0))","List(1, 3, List(), List(0.9590570719602978, 0.04094292803970223, 0.0))",0.0
1,1,22,0,0,"List(1, 4, List(), List(1.0, 1.0, 22.0, 0.0))","List(1, 3, List(), List(773.0, 33.0, 0.0))","List(1, 3, List(), List(0.9590570719602978, 0.04094292803970223, 0.0))",0.0
1,1,22,0,0,"List(1, 4, List(), List(1.0, 1.0, 22.0, 0.0))","List(1, 3, List(), List(773.0, 33.0, 0.0))","List(1, 3, List(), List(0.9590570719602978, 0.04094292803970223, 0.0))",0.0
1,1,22,0,0,"List(1, 4, List(), List(1.0, 1.0, 22.0, 0.0))","List(1, 3, List(), List(773.0, 33.0, 0.0))","List(1, 3, List(), List(0.9590570719602978, 0.04094292803970223, 0.0))",0.0
1,1,22,0,0,"List(1, 4, List(), List(1.0, 1.0, 22.0, 0.0))","List(1, 3, List(), List(773.0, 33.0, 0.0))","List(1, 3, List(), List(0.9590570719602978, 0.04094292803970223, 0.0))",0.0
1,1,22,0,0,"List(1, 4, List(), List(1.0, 1.0, 22.0, 0.0))","List(1, 3, List(), List(773.0, 33.0, 0.0))","List(1, 3, List(), List(0.9590570719602978, 0.04094292803970223, 0.0))",0.0
1,1,22,0,0,"List(1, 4, List(), List(1.0, 1.0, 22.0, 0.0))","List(1, 3, List(), List(773.0, 33.0, 0.0))","List(1, 3, List(), List(0.9590570719602978, 0.04094292803970223, 0.0))",0.0
1,1,22,0,0,"List(1, 4, List(), List(1.0, 1.0, 22.0, 0.0))","List(1, 3, List(), List(773.0, 33.0, 0.0))","List(1, 3, List(), List(0.9590570719602978, 0.04094292803970223, 0.0))",0.0
1,1,23,0,0,"List(1, 4, List(), List(1.0, 1.0, 23.0, 0.0))","List(1, 3, List(), List(773.0, 33.0, 0.0))","List(1, 3, List(), List(0.9590570719602978, 0.04094292803970223, 0.0))",0.0


In [3]:
# Largest predicted class index on the test split. Aggregating in Spark avoids
# collecting every prediction to the driver; yields None (instead of raising
# ValueError like max([])) when the test split is empty.
max_prediction = pred.agg(func.max('prediction')).first()[0]
print(max_prediction)

In [4]:
# Bring the class-probability vector of every test row back to the driver as a
# list of Row objects. collect() materializes the whole column locally, so this
# is only reasonable while the test split stays small.
prob = pred.select(pred['probability']).collect()


In [5]:
# Test rows the decision tree assigned to class 1.
x = pred.where(pred['prediction'] == 1)
display(x)

gender1,race1,age,complaintsbefore2017,complaints2017,features,rawPrediction,probability,prediction
1,4,32,2,0,"List(1, 4, List(), List(1.0, 4.0, 32.0, 2.0))","List(1, 3, List(), List(0.0, 1.0, 0.0))","List(1, 3, List(), List(0.0, 1.0, 0.0))",1.0
2,3,32,4,0,"List(1, 4, List(), List(2.0, 3.0, 32.0, 4.0))","List(1, 3, List(), List(0.0, 1.0, 0.0))","List(1, 3, List(), List(0.0, 1.0, 0.0))",1.0
2,4,32,2,0,"List(1, 4, List(), List(2.0, 4.0, 32.0, 2.0))","List(1, 3, List(), List(0.0, 1.0, 0.0))","List(1, 3, List(), List(0.0, 1.0, 0.0))",1.0


In [6]:
# Number of test rows predicted as class 1.
n_predicted_class1 = x.count()
print(n_predicted_class1)