In [43]:
from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

In [2]:
# Reuse the active SparkSession if there is one, otherwise start one named 'log_reg'.
spark = (
    SparkSession.builder
    .appName('log_reg')
    .getOrCreate()
)
spark

In [3]:
# Load the LIBSVM-formatted sample file into a DataFrame with `label` and `features` columns.
df = spark.read.load('sample_libsvm_data.txt', format='libsvm')
df.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [4]:
# Total number of rows in the loaded dataset.
df.count()

100

In [5]:
# Column names and types: a double `label` and a vector `features`.
df.printSchema()

root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)



In [6]:
# Fit a logistic regression with default hyperparameters on the *entire* dataset.
# NOTE: nothing is held out here, so every metric read from this model's summary
# is a training-set metric. The train/test split further down gives a fairer estimate.
lr = LogisticRegression()
lr_model = lr.fit(dataset=df)

In [11]:
# Training summary of the model fitted on all of `df`. Metrics read from it are
# measured on the same rows the model was trained on, so they are optimistic.
results = lr_model.summary
results

<pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary at 0x15e2e378410>

In [12]:
# Training-set accuracy (fraction of rows whose prediction matches the label).
results.accuracy

1.0

In [13]:
# Training-set area under the ROC curve.
results.areaUnderROC

1.0

In [14]:
# Training-set false positive rate, one entry per label.
results.falsePositiveRateByLabel

[0.0, 0.0]

In [16]:
# Training-set F1 score per label (beta=1.0 weighs precision and recall equally).
results.fMeasureByLabel(beta=1.0)

[1.0, 1.0]

In [17]:
# Training-set precision, one entry per label.
results.precisionByLabel

[1.0, 1.0]

In [20]:
# First rows of the training-set precision/recall curve.
results.pr.show(n=20)

+--------------------+---------+
|              recall|precision|
+--------------------+---------+
|                 0.0|      1.0|
|0.017543859649122806|      1.0|
| 0.03508771929824561|      1.0|
| 0.05263157894736842|      1.0|
| 0.07017543859649122|      1.0|
| 0.08771929824561403|      1.0|
| 0.10526315789473684|      1.0|
| 0.12280701754385964|      1.0|
| 0.14035087719298245|      1.0|
| 0.15789473684210525|      1.0|
| 0.17543859649122806|      1.0|
| 0.19298245614035087|      1.0|
| 0.21052631578947367|      1.0|
| 0.22807017543859648|      1.0|
| 0.24561403508771928|      1.0|
|  0.2631578947368421|      1.0|
|  0.2807017543859649|      1.0|
|  0.2982456140350877|      1.0|
|  0.3157894736842105|      1.0|
|  0.3333333333333333|      1.0|
+--------------------+---------+
only showing top 20 rows



In [21]:
# Objective function value at each training iteration; a steadily shrinking
# sequence indicates the optimizer converged.
results.objectiveHistory

[0.6833149135741672,
 0.013093751340219117,
 0.010701411598307597,
 0.003694038375993421,
 0.0021399761338157177,
 0.0011363374603547136,
 0.0006407228823852692,
 0.00036705428551573704,
 0.0002260026422054868,
 0.00015094375366593175,
 0.0001071063166363229,
 7.553627222762171e-05,
 2.580774200701701e-05,
 1.4067994122146346e-05,
 7.617742760820017e-06,
 3.870751420627088e-06,
 1.9854616143530575e-06,
 1.0073600214465829e-06,
 5.106945419720955e-07,
 2.5790151660584507e-07,
 1.3000317763446567e-07,
 6.541799796737615e-08,
 3.288742060702569e-08,
 1.6455075322246003e-08,
 8.237108339716514e-09,
 4.121711294758356e-09,
 2.0624916722421435e-09,
 1.0319596771097586e-09]

In [24]:
# First rows of the training-set ROC curve (false positive rate vs. true positive rate).
results.roc.show(n=20)

+---+--------------------+
|FPR|                 TPR|
+---+--------------------+
|0.0|                 0.0|
|0.0|0.017543859649122806|
|0.0| 0.03508771929824561|
|0.0| 0.05263157894736842|
|0.0| 0.07017543859649122|
|0.0| 0.08771929824561403|
|0.0| 0.10526315789473684|
|0.0| 0.12280701754385964|
|0.0| 0.14035087719298245|
|0.0| 0.15789473684210525|
|0.0| 0.17543859649122806|
|0.0| 0.19298245614035087|
|0.0| 0.21052631578947367|
|0.0| 0.22807017543859648|
|0.0| 0.24561403508771928|
|0.0|  0.2631578947368421|
|0.0|  0.2807017543859649|
|0.0|  0.2982456140350877|
|0.0|  0.3157894736842105|
|0.0|  0.3333333333333333|
+---+--------------------+
only showing top 20 rows



In [25]:
# Training-set precision averaged across labels, weighted by each label's frequency.
results.weightedPrecision

1.0

In [26]:
# Training-set recall averaged across labels, weighted by each label's frequency.
results.weightedRecall

1.0

In [28]:
# Training-set F1 averaged across labels, weighted by label frequency (beta=1.0).
results.weightedFMeasure(beta=1.0)

1.0

In [22]:
# Per-row model output on the training data: raw margins, class probabilities, predicted label.
results.predictions.show(n=20)

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[127,128,129...|[20.3777627514872...|[0.99999999858729...|       0.0|
|  1.0|(692,[158,159,160...|[-21.114014198868...|[6.76550380000472...|       1.0|
|  1.0|(692,[124,125,126...|[-23.743613234676...|[4.87842678716177...|       1.0|
|  1.0|(692,[152,153,154...|[-19.192574012720...|[4.62137287298144...|       1.0|
|  1.0|(692,[151,152,153...|[-20.125398874699...|[1.81823629113068...|       1.0|
|  0.0|(692,[129,130,131...|[20.4890549504196...|[0.99999999873608...|       0.0|
|  1.0|(692,[158,159,160...|[-21.082940212814...|[6.97903542823766...|       1.0|
|  1.0|(692,[99,100,101,...|[-19.622713503550...|[3.00582577446132...|       1.0|
|  0.0|(692,[154,155,156...|[21.1594863606582...|[0.99999999935352...|       0.0|
|  0.0|(692,[127

In [29]:
# Hold out roughly 30% of the rows for evaluation. The split sizes are approximate
# (e.g. 79/21 rather than 70/30), and without a seed the split changes on every run.
# Fixing the seed makes the split, and every test metric derived from it, reproducible.
train_df, test_df = df.randomSplit([0.7, 0.3], seed=42)
train_df.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[95,96,97,12...|
|  0.0|(692,[100,101,102...|
|  0.0|(692,[121,122,123...|
|  0.0|(692,[122,123,124...|
|  0.0|(692,[122,123,148...|
|  0.0|(692,[123,124,125...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[125,126,127...|
|  0.0|(692,[126,127,128...|
|  0.0|(692,[126,127,128...|
|  0.0|(692,[126,127,128...|
|  0.0|(692,[126,127,128...|
|  0.0|(692,[126,127,128...|
|  0.0|(692,[126,127,128...|
|  0.0|(692,[126,127,128...|
+-----+--------------------+
only showing top 20 rows



In [30]:
# Peek at the held-out rows.
test_df.show(n=20)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[98,99,100,1...|
|  0.0|(692,[123,124,125...|
|  0.0|(692,[123,124,125...|
|  0.0|(692,[126,127,128...|
|  0.0|(692,[152,153,154...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[234,235,237...|
|  1.0|(692,[99,100,101,...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[125,126,127...|
|  1.0|(692,[125,126,153...|
|  1.0|(692,[127,128,129...|
|  1.0|(692,[127,128,155...|
|  1.0|(692,[128,129,130...|
|  1.0|(692,[128,129,130...|
|  1.0|(692,[128,129,130...|
|  1.0|(692,[129,130,131...|
|  1.0|(692,[150,151,152...|
|  1.0|(692,[151,152,153...|
+-----+--------------------+
only showing top 20 rows



In [31]:
# Summary statistics of the label in the training split. The mean is the share of
# label 1.0 rows; the vector `features` column is not describable.
train_df.describe('label').show()

+-------+------------------+
|summary|             label|
+-------+------------------+
|  count|                79|
|   mean|0.5569620253164557|
| stddev|0.4999188509286224|
|    min|               0.0|
|    max|               1.0|
+-------+------------------+



In [32]:
# Summary statistics of the label in the test split; compare its mean with the
# training split to check that class balance is similar.
test_df.describe('label').show()

+-------+-------------------+
|summary|              label|
+-------+-------------------+
|  count|                 21|
|   mean| 0.6190476190476191|
| stddev|0.49761335152811925|
|    min|                0.0|
|    max|                1.0|
+-------+-------------------+



In [33]:
# Refit from scratch on the training split only. This rebinds `lr` and `lr_model`,
# replacing the full-data model fitted earlier in the notebook.
lr = LogisticRegression()
lr_model = lr.fit(dataset=train_df)
lr_model.summary

<pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary at 0x15e4020e690>

In [34]:
# Score the held-out split with the train-only model; the returned summary
# exposes the same metrics as the training summary, computed on unseen rows.
test_results = lr_model.evaluate(dataset=test_df)
test_results

<pyspark.ml.classification.BinaryLogisticRegressionSummary at 0x15e40239590>

In [35]:
# Accuracy on the held-out test split.
test_results.accuracy

1.0

In [37]:
# First rows of the test-split ROC curve (false positive rate vs. true positive rate).
test_results.roc.show(n=20)

+-----+-------------------+
|  FPR|                TPR|
+-----+-------------------+
|  0.0|                0.0|
|  0.0|0.07692307692307693|
|  0.0|0.15384615384615385|
|  0.0|0.23076923076923078|
|  0.0| 0.3076923076923077|
|  0.0|0.38461538461538464|
|  0.0|0.46153846153846156|
|  0.0| 0.5384615384615384|
|  0.0| 0.6153846153846154|
|  0.0| 0.6923076923076923|
|  0.0| 0.7692307692307693|
|  0.0| 0.8461538461538461|
|  0.0| 0.9230769230769231|
|  0.0|                1.0|
|0.125|                1.0|
| 0.25|                1.0|
|0.375|                1.0|
|  0.5|                1.0|
|0.625|                1.0|
| 0.75|                1.0|
+-----+-------------------+
only showing top 20 rows



In [38]:
# First rows of the test-split precision/recall curve.
test_results.pr.show(n=20)

+-------------------+------------------+
|             recall|         precision|
+-------------------+------------------+
|                0.0|               1.0|
|0.07692307692307693|               1.0|
|0.15384615384615385|               1.0|
|0.23076923076923078|               1.0|
| 0.3076923076923077|               1.0|
|0.38461538461538464|               1.0|
|0.46153846153846156|               1.0|
| 0.5384615384615384|               1.0|
| 0.6153846153846154|               1.0|
| 0.6923076923076923|               1.0|
| 0.7692307692307693|               1.0|
| 0.8461538461538461|               1.0|
| 0.9230769230769231|               1.0|
|                1.0|               1.0|
|                1.0|0.9285714285714286|
|                1.0|0.8666666666666667|
|                1.0|            0.8125|
|                1.0|0.7647058823529411|
|                1.0|0.7222222222222222|
|                1.0|0.6842105263157895|
+-------------------+------------------+
only showing top 20 rows

In [39]:
# Area under the ROC curve on the held-out test split.
test_results.areaUnderROC

1.0

In [41]:
# Test-split F1 score per label (beta=1.0 weighs precision and recall equally).
test_results.fMeasureByLabel(beta=1.0)

[1.0, 1.0]

In [42]:
# Per-row model output on the test split: raw margins, class probabilities, predicted label.
test_results.predictions.show(n=20)

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[98,99,100,1...|[32.2163396254175...|[0.99999999999998...|       0.0|
|  0.0|(692,[123,124,125...|[36.4345001296687...|[0.99999999999999...|       0.0|
|  0.0|(692,[123,124,125...|[35.8280567859877...|[0.99999999999999...|       0.0|
|  0.0|(692,[126,127,128...|[24.9725866590964...|[0.99999999998572...|       0.0|
|  0.0|(692,[152,153,154...|[11.4143807763787...|[0.99998896448458...|       0.0|
|  0.0|(692,[154,155,156...|[14.2324152110700...|[0.99999934091667...|       0.0|
|  0.0|(692,[154,155,156...|[19.3791935971361...|[0.99999999616537...|       0.0|
|  0.0|(692,[234,235,237...|[0.99563574687644...|[0.73019964960712...|       0.0|
|  1.0|(692,[99,100,101,...|[-2.3808212504864...|[0.08464691204540...|       1.0|
|  1.0|(692,[124

In [44]:
# Cross-check the summary's AUC with a standalone evaluator applied to the test
# predictions (it reads the `rawPrediction` and `label` columns by default).
evaluator = BinaryClassificationEvaluator(metricName='areaUnderROC')
area_under_roc = evaluator.evaluate(test_results.predictions)
area_under_roc # Area under ROC curve

1.0

In [45]:
# Same evaluator type, switched to the area under the precision-recall curve.
evaluator = BinaryClassificationEvaluator().setMetricName('areaUnderPR')
area_under_pr = evaluator.evaluate(test_results.predictions)
area_under_pr # Area under precision-recall curve

1.0