In [1]:
%config IPCompleter.greedy=True

In [2]:
from pyspark import SparkContext

# Reuse the active SparkContext if one exists so this cell can be re-run
# without a "Cannot run multiple SparkContexts at once" error.
sc = SparkContext.getOrCreate()

In [3]:
from pyspark.mllib.regression import LabeledPoint

def parsePoint(line):
    """Turn one comma-separated line of integers into a LabeledPoint.

    The first field is used as the label; all remaining fields, in order,
    are used as the feature values.
    """
    label, *features = map(int, line.split(","))
    return LabeledPoint(label, features)
    

In [4]:
from itertools import islice

def load_labeled_points(path):
    """Read a CSV file into a cached RDD of LabeledPoint records.

    The very first line of the file is dropped (presumably a header row --
    confirm against the data files); every other line goes through parsePoint.
    The resulting RDD is cached because it is reused by several actions
    (iterative training and the evaluation cells), which would otherwise
    re-read and re-parse the text file each time.
    """
    def drop_first_line(idx, rows):
        # Only partition 0 contains the first line of the file.
        return islice(rows, 1, None) if idx == 0 else rows

    return (
        sc.textFile(path)
        .mapPartitionsWithIndex(drop_first_line)
        .map(parsePoint)
        .cache()
    )


training_data = load_labeled_points('./mnist/mnist_train.csv')

test_data = load_labeled_points('./mnist/mnist_test.csv')

In [5]:
from pyspark.mllib.classification import LogisticRegressionWithLBFGS

In [6]:
# Fit a 10-class logistic regression with L-BFGS; regParam is set to 0.
model = LogisticRegressionWithLBFGS.train(
    training_data,
    numClasses=10,
    regParam=0,
)

In [7]:
# Number of misclassified test examples.
# Predict once per record and filter directly on the mismatch instead of
# building an intermediate tuple that calls model.predict twice.
test_data.filter(
    lambda d: model.predict(d.features) != int(d.label)
).count()

752

In [8]:
# Number of misclassified training examples.
# Predict once per record and filter directly on the mismatch instead of
# building an intermediate tuple that calls model.predict twice.
training_data.filter(
    lambda d: model.predict(d.features) != int(d.label)
).count()

3706

In [9]:
# Pair each test example's predicted class (as a float) with its label.
pred_labels = test_data.map(
    lambda point: (float(model.predict(point.features)), point.label)
)

In [10]:
from pyspark.mllib.evaluation import MulticlassMetrics

In [11]:
# Build the multiclass evaluator from the pred_labels RDD.
metrics = MulticlassMetrics(pred_labels)

In [12]:
# Overall statistics
# The label-free precision()/recall()/fMeasure() calls were deprecated in
# Spark 2.0 and removed in Spark 3.0. Without a label they all return the
# overall accuracy (hence three identical numbers), so read that value from
# the `accuracy` property instead.
precision = recall = f1Score = metrics.accuracy
print("Summary Stats")
print("Precision = %s" % precision)
print("Recall = %s" % recall)
print("F1 Score = %s" % f1Score)

Summary Stats
Precision = 0.9248
Recall = 0.9248
F1 Score = 0.9248


In [13]:

# Statistics by class
# Distinct values of the second element (the label) of each pred_labels pair.
labels = pred_labels.map(lambda pair: pair[1]).distinct().collect()
for label in sorted(labels):
    class_precision = metrics.precision(label)
    class_recall = metrics.recall(label)
    class_f1 = metrics.fMeasure(label, beta=1.0)
    print("Class %s precision = %s" % (label, class_precision))
    print("Class %s recall = %s" % (label, class_recall))
    print("Class %s F1 Measure = %s" % (label, class_f1))

Class 0.0 precision = 0.9561316051844466
Class 0.0 recall = 0.9785714285714285
Class 0.0 F1 Measure = 0.967221381744831
Class 1.0 precision = 0.9660869565217391
Class 1.0 recall = 0.9788546255506608
Class 1.0 F1 Measure = 0.9724288840262582
Class 2.0 precision = 0.936734693877551
Class 2.0 recall = 0.8895348837209303
Class 2.0 F1 Measure = 0.9125248508946321
Class 3.0 precision = 0.90234375
Class 3.0 recall = 0.9148514851485149
Class 3.0 F1 Measure = 0.9085545722713865
Class 4.0 precision = 0.934560327198364
Class 4.0 recall = 0.9307535641547862
Class 4.0 F1 Measure = 0.9326530612244899
Class 5.0 precision = 0.9090909090909091
Class 5.0 recall = 0.8632286995515696
Class 5.0 F1 Measure = 0.8855664174813112
Class 6.0 precision = 0.9334016393442623
Class 6.0 recall = 0.9509394572025052
Class 6.0 F1 Measure = 0.9420889348500517
Class 7.0 precision = 0.9269005847953217
Class 7.0 recall = 0.9250972762645915
Class 7.0 F1 Measure = 0.9259980525803312
Class 8.0 precision = 0.8615232443125618
Cl

In [14]:
# Weighted stats
# Gather the weighted metrics first, then print them in a fixed order.
weighted_recall = metrics.weightedRecall
weighted_precision = metrics.weightedPrecision
weighted_f1 = metrics.weightedFMeasure()
weighted_f_half = metrics.weightedFMeasure(beta=0.5)
weighted_fpr = metrics.weightedFalsePositiveRate

print("Weighted recall = %s" % weighted_recall)
print("Weighted precision = %s" % weighted_precision)
print("Weighted F(1) Score = %s" % weighted_f1)
print("Weighted F(0.5) Score = %s" % weighted_f_half)
print("Weighted false positive rate = %s" % weighted_fpr)

Weighted recall = 0.9248000000000001
Weighted precision = 0.9249076315597062
Weighted F(1) Score = 0.9246777208704127
Weighted F(0.5) Score = 0.9247725529819946
Weighted false positive rate = 0.008289218806257687
