In [1]:
from pyspark.sql import SparkSession
# Start a Spark session named 'LogReg', or reuse the one already running.
spark = (
    SparkSession.builder
    .appName('LogReg')
    .getOrCreate()
)

In [2]:
from pyspark.ml.classification import LogisticRegression

In [4]:
# Read the LIBSVM-formatted sample file into a DataFrame; the displayed rows
# below show it has a `label` column and a sparse `features` vector column.
my_data = spark.read.load(
    'sample_libsvm_data.txt',
    format='libsvm',
)

In [5]:
# Peek at the first five rows only; never dump the whole DataFrame.
my_data.show(n=5)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
+-----+--------------------+
only showing top 5 rows



## Build a model

In [45]:
# Unfitted estimator: reads inputs from the `features` column and the target
# from the `label` column. It is reused for every fit in this notebook.
logReg = LogisticRegression(
    labelCol='label',
    featuresCol='features',
)

In [46]:
# Fit on every loaded row (no hold-out yet), so anything reported by this
# model's summary describes performance on its own training data.
fitted_logReg = logReg.fit(dataset=my_data)

## Model summary after fitting

In [9]:
# Summary object attached to the fitted model; its `predictions` DataFrame is
# inspected in the following cells.
summary = fitted_logReg.summary

In [10]:
# List the columns of the summary's predictions DataFrame: the input columns
# (label, features) plus rawPrediction, probability and prediction.
summary.predictions.printSchema()

root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = true)



In [11]:
# First five scored rows from the summary's predictions DataFrame.
summary.predictions.show(n=5)

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[127,128,129...|[39.9727764450750...|[1.0,4.3655982185...|       0.0|
|  1.0|(692,[158,159,160...|[-35.662380562160...|[3.25105944320044...|       1.0|
|  1.0|(692,[124,125,126...|[-39.336799621156...|[8.24603148700906...|       1.0|
|  1.0|(692,[152,153,154...|[-28.219286248176...|[5.55289803944932...|       1.0|
|  1.0|(692,[151,152,153...|[-28.142070329444...|[5.99865861146384...|       1.0|
+-----+--------------------+--------------------+--------------------+----------+
only showing top 5 rows



## Split the data into training and test sets, then evaluate the fitted model on the held-out data

In [51]:
# Hold out roughly 30% of the rows for testing (randomSplit weights are
# approximate, not exact). The fixed seed makes the split, and therefore every
# metric computed from it below, repeatable across runs.
train_data, test_data = my_data.randomSplit([0.7, 0.3], seed=42)

In [52]:
# Refit the same estimator, this time on the training split only.
fit_final = logReg.fit(dataset=train_data)

In [53]:
# Score the held-out rows with the model trained on train_data. Despite its
# name, this variable holds the evaluation result object; the scored rows are
# reached through its `.predictions` DataFrame in the cells below.
predictions_and_labels = fit_final.evaluate(dataset=test_data)

In [54]:
predictions_and_labels.predictions.show(5)

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[95,96,97,12...|[53.3356415452094...|[1.0,6.8647573088...|       0.0|
|  0.0|(692,[100,101,102...|[23.1681786624138...|[0.99999999991326...|       0.0|
|  0.0|(692,[121,122,123...|[80.7466955046080...|[1.0,8.5537334048...|       0.0|
|  0.0|(692,[122,123,124...|[69.7504913868260...|[1.0,5.1020707972...|       0.0|
|  0.0|(692,[122,123,148...|[70.5686579127862...|[1.0,2.2512372861...|       0.0|
+-----+--------------------+--------------------+--------------------+----------+
only showing top 5 rows



In [145]:
# Accuracy on the held-out rows: the share of rows whose predicted class
# equals the true label, i.e. (true positives + true negatives) / all rows.
test_predictions = predictions_and_labels.predictions
TP_and_TN = test_predictions.filter(
    test_predictions['label'] == test_predictions['prediction']).count()
total = test_predictions.count()
# randomSplit gives no guarantee on split sizes, so the test split can be
# empty; report NaN in that case instead of raising ZeroDivisionError.
acc = TP_and_TN / total if total else float('nan')
print("Model accuracy: %.6f" % acc)

Model accuracy: 0.971429


## Try built-in evaluators

In [55]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

In [142]:
# Area under the ROC curve, computed from the `rawPrediction` column against
# the `label` column of the test-set predictions.
evaluator = BinaryClassificationEvaluator(
    labelCol="label",
    rawPredictionCol="rawPrediction",
    metricName='areaUnderROC',
)
evaluator.evaluate(predictions_and_labels.predictions)

1.0

In [144]:
# Accuracy again, this time from the built-in multiclass evaluator comparing
# the `prediction` column with `label`; the result should agree with the
# hand-computed figure above.
evaluator2 = MulticlassClassificationEvaluator(
    labelCol='label',
    predictionCol="prediction",
    metricName='accuracy',
)
evaluator2.evaluate(predictions_and_labels.predictions)

0.9714285714285714

In [113]:
# Side-by-side view of the true label and the predicted class (first five rows).
predictions_and_labels.predictions.select('label', 'prediction').show(n=5)

+-----+----------+
|label|prediction|
+-----+----------+
|  0.0|       0.0|
|  0.0|       0.0|
|  0.0|       0.0|
|  0.0|       0.0|
|  0.0|       0.0|
+-----+----------+
only showing top 5 rows

