In [1]:
    // Path of the sample data set that is handed to MLUtils.loadLibSVMFile below.
    var path = "../data/mllib/sample_multiclass_classification_data.txt";

    // Create the Spark configuration and the context built from it.
    var sparkConf = new SparkConf()
        .setAppName("Multi class Classification Metrics Example");
    var sc = new SparkContext(sparkConf);

    // Load the data set through the context.
    var data = MLUtils.loadLibSVMFile(sc, path);


 Split the initial RDD into two parts: 60% training data and 40% test data.


In [2]:
    // Split the data with weights 0.6 and 0.4; 11 is passed as the second
    // argument (the seed in Spark's randomSplit).
    var splits = data.randomSplit(
        [0.6, 0.4],
        11
    );
    // The first part is cached and used for training; the second is the test set.
    var training = splits[0].cache();
    var test = splits[1];


 Run training algorithm to build the model.


In [3]:
    // Set the number of classes to 3 and run the L-BFGS logistic regression trainer on the training split.
    var model = new LogisticRegressionWithLBFGS().setNumClasses(3).run(training);


 Compute predictions on the test set and pair each one with its true label.


In [4]:
    // Map every test example to a (prediction, label) tuple. The [model] array
    // passed as map()'s second argument appears to supply the callback's second
    // parameter (EclairJS bound-argument convention - confirm).
    var predictionAndLabels = test.map(function (labeledPoint, model) {
        return new Tuple(
            model.predict(labeledPoint.getFeatures()),
            labeledPoint.getLabel()
        );
    }, [model]);
    // Holder object that carries the trained model.
    var ret = { model: model };


 Get evaluation metrics.


In [5]:
    // Build the multiclass metrics from predictionAndLabels and store them on `ret`.
    ret.metrics = new MulticlassMetrics(predictionAndLabels);

// `result` is an alias of the same object as `ret`; the code below reads
// result.metrics and result.model.
var result = ret;


 Confusion matrix


In [6]:
    // Fetch the confusion matrix from the metrics object and print it; the
    // concatenation relies on the matrix's string conversion.
    var confusion = result.metrics.confusionMatrix();
    print("Confusion matrix: \n" + confusion);


 Overall statistics


In [7]:
    // Print the overall precision, recall and F1 score (the no-argument
    // variants of the metric methods).
    (function (metrics) {
        print("Precision = " + metrics.precision());
        print("Recall = " + metrics.recall());
        print("F1 Score = " + metrics.fMeasure());
    })(result.metrics);


 Statistics by label


In [8]:
    // Print precision, recall and F1 score for each label. labels() is called
    // once up front instead of in the loop condition and in every print
    // argument; the printed output is unchanged.
    var labels = result.metrics.labels();
    for (var i = 0; i < labels.length; i++) {
        var label = labels[i];
        print("Class " + label + " precision = " + result.metrics.precision(label));
        print("Class " + label + " recall = " + result.metrics.recall(label));
        print("Class " + label + " F1 score = " + result.metrics.fMeasure(label));
    }


Weighted stats


In [9]:
    // Print the weighted precision, recall, F1 score and false positive rate.
    (function (metrics) {
        print("Weighted precision = " + metrics.weightedPrecision());
        print("Weighted recall = " + metrics.weightedRecall());
        print("Weighted F1 score = " + metrics.weightedFMeasure());
        print("Weighted false positive rate = " + metrics.weightedFalsePositiveRate());
    })(result.metrics);


 Save and load model


In [10]:
    // Save the trained model, then load it back from the same path.
    // The context is stopped in `finally` so that a failing save/load (for
    // example when the target directory is left over from an earlier run)
    // does not leave the SparkContext running; the error still propagates.
    try {
        result.model.save(sc, "target/tmp/LogisticRegressionModel");
        var sameModel = LogisticRegressionModel.load(sc, "target/tmp/LogisticRegressionModel");
    } finally {
        sc.stop();
    }
