Multi-label metrics: Hamming-loss, strict and normal accuracy, fix to macro measures, bunch of tests
avulanov committed Jun 30, 2014
1 parent ad62df0 commit 40593f5
Showing 2 changed files with 74 additions and 31 deletions.
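
For context, a minimal usage sketch of the metrics touched by this commit, based on the class, constructor and member names visible in the diff below; the package import path and the SparkContext setup are assumptions for illustration, not part of the patch:

import org.apache.spark.SparkContext
import org.apache.spark.mllib.evaluation.MultilabelMetrics // assumed package path

object MultilabelMetricsExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "MultilabelMetricsExample")
    // (predictions, labels) per document, encoded as Set[Double] label indices
    val predictionAndLabels = sc.parallelize(Seq(
      (Set(0.0, 1.0), Set(0.0, 2.0)), // one correct label, one spurious, one missed
      (Set(1.0, 2.0), Set(1.0, 2.0)), // exact match
      (Set(0.0), Set(0.0, 1.0))))     // correct but incomplete
    val metrics = new MultilabelMetrics(predictionAndLabels)
    println("strict accuracy = " + metrics.strictAccuracy)
    println("accuracy        = " + metrics.accuracy)
    println("Hamming loss    = " + metrics.hammingLoss)
    println("micro F1        = " + metrics.microF1MeasureClass)
    sc.stop()
  }
}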
@@ -31,7 +31,33 @@ import org.apache.spark.SparkContext._
*/
class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) extends Logging{

private lazy val numDocs = predictionAndLabels.count()
private lazy val numDocs = predictionAndLabels.count

private lazy val numLabels = predictionAndLabels.flatMap{case(_, labels) => labels}.distinct.count

/**
* Returns strict Accuracy
* (for equal sets of labels)
* @return strictAccuracy.
*/
lazy val strictAccuracy = predictionAndLabels.filter{case(predictions, labels) =>
predictions == labels}.count.toDouble / numDocs

/**
* Returns Accuracy
* @return Accuracy.
*/
lazy val accuracy = predictionAndLabels.map{ case(predictions, labels) =>
labels.intersect(predictions).size.toDouble / labels.union(predictions).size}.
fold(0.0)(_ + _) / numDocs

/**
* Returns Hamming-loss
* @return hammingLoss.
*/
lazy val hammingLoss = (predictionAndLabels.map{ case(predictions, labels) =>
labels.diff(predictions).size + predictions.diff(labels).size}.
fold(0)(_ + _)).toDouble / (numDocs * numLabels)
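
// Worked example for the three measures above (illustrative numbers, not part
// of this patch). Two documents: predictions {0.0, 1.0} vs. labels {0.0, 2.0},
// and predictions {1.0} vs. labels {1.0}; numLabels = |{0.0, 1.0, 2.0}| = 3.
//   strictAccuracy = 1 exact match / 2 docs                        = 0.5
//   accuracy       = (|{0.0}| / |{0.0, 1.0, 2.0}| + 1 / 1) / 2     = 2 / 3
//   hammingLoss    = ((1 missed + 1 spurious) + (0 + 0)) / (2 * 3) = 1 / 3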

/**
* Returns Document-based Precision averaged by the number of documents
@@ -47,31 +73,36 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
* @return macroRecallDoc.
*/
lazy val macroRecallDoc = (predictionAndLabels.map{ case(predictions, labels) =>
predictions.intersect(labels).size.toDouble / labels.size}.fold(0.0)(_ + _)) / numDocs
labels.intersect(predictions).size.toDouble / labels.size}.fold(0.0)(_ + _)) / numDocs

/**
* Returns Document-based F1-measure averaged by the number of documents
* @return macroF1MeasureDoc.
*/
lazy val macroF1MeasureDoc = (predictionAndLabels.map{ case(predictions, labels) =>
2.0 * predictions.intersect(labels).size /
(predictions.size + labels.size)}.fold(0.0)(_ + _)) / numDocs
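
// Continuing the illustrative two-document example above (predictions
// {0.0, 1.0} vs. labels {0.0, 2.0}, and {1.0} vs. {1.0}), the
// document-averaged measures come out as:
//   macroPrecisionDoc = (1/2 + 1/1) / 2                   = 0.75
//   macroRecallDoc    = (1/2 + 1/1) / 2                   = 0.75
//   macroF1MeasureDoc = (2 * 1/(2 + 2) + 2 * 1/(1 + 1)) / 2 = 0.75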

/**
* Returns micro-averaged document-based Precision
* (equals to label-based microPrecision)
* @return microPrecisionDoc.
*/
lazy val microPrecisionDoc = {
val (sumTp, sumPredictions) = predictionAndLabels.map{ case(predictions, labels) =>
(predictions.intersect(labels).size, predictions.size)}.
fold((0, 0)){ case((tp1, predictions1), (tp2, predictions2)) =>
(tp1 + tp2, predictions1 + predictions2)}
sumTp.toDouble / sumPredictions
}
lazy val microPrecisionDoc = microPrecisionClass

/**
* Returns micro-averaged document-based Recall
* (equals to label-based microRecall)
* @return microRecallDoc.
*/
lazy val microRecallDoc = {
val (sumTp, sumLabels) = predictionAndLabels.map{ case(predictions, labels) =>
(predictions.intersect(labels).size, labels.size)}.
fold((0, 0)){ case((tp1, labels1), (tp2, labels2)) =>
(tp1 + tp2, labels1 + labels2)}
sumTp.toDouble / sumLabels
}
lazy val microRecallDoc = microRecallClass

/**
* Returns micro-averaged document-based F1-measure
* (equals to label-based microF1measure)
* @return microF1MeasureDoc.
*/
lazy val microF1MeasureDoc = microF1MeasureClass
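
// Why these are plain aliases: summed over all documents, the correctly
// predicted labels add up to the total tp, the predicted labels to the total
// (tp + fp), and the true labels to the total (tp + fn), which are exactly the
// totals the label-based micro measures are built from. Hence micro precision,
// recall and F1 over documents and over labels coincide.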

private lazy val tpPerClass = predictionAndLabels.flatMap{ case(predictions, labels) =>
predictions.intersect(labels).map(category => (category, 1))}.reduceByKey(_ + _).collectAsMap()
@@ -110,7 +141,9 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
if((precision + recall) == 0) 0 else 2 * precision * recall / (precision + recall)
}

private lazy val sumTp = tpPerClass.foldLeft(0L){ case(sumTp, (_, tp)) => sumTp + tp}
private lazy val sumTp = tpPerClass.foldLeft(0L){ case(sum, (_, tp)) => sum + tp}
private lazy val sumFpClass = fpPerClass.foldLeft(0L){ case(sum, (_, fp)) => sum + fp}
private lazy val sumFnClass = fnPerClass.foldLeft(0L){ case(sum, (_, fn)) => sum + fn}

/**
* Returns micro-averaged label-based Precision
@@ -134,10 +167,6 @@ class MultilabelMetrics(predictionAndLabels:RDD[(Set[Double], Set[Double])]) ext
* Returns micro-averaged label-based F1-measure
* @return microF1MeasureClass.
*/
lazy val microF1MeasureClass = {
val precision = microPrecisionClass
val recall = microRecallClass
if((precision + recall) == 0) 0 else 2 * precision * recall / (precision + recall)
}
lazy val microF1MeasureClass = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
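
// The one-liner above agrees with the removed precision/recall form: with
// P = sumTp / (sumTp + sumFpClass) and R = sumTp / (sumTp + sumFnClass),
//   2 * P * R / (P + R)
//     = 2 * sumTp * sumTp / (sumTp * ((sumTp + sumFnClass) + (sumTp + sumFpClass)))
//     = 2 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)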

}
@@ -47,17 +47,26 @@ class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
val microPrecisionClass = (4.0 + 2.0 + 2.0) / (4 + 0 + 2 + 1 + 2 + 2)
val microRecallClass = (4.0 + 2.0 + 2.0) / (4 + 1 + 2 + 1 + 2 + 2)
val microF1MeasureClass = 2 * microPrecisionClass * microRecallClass / (microPrecisionClass + microRecallClass)
val sumTp = 4 + 2 + 2
assert(sumTp == (1 + 1 + 0 + 1 + 2 + 2 + 1))
val microPrecisionClass = sumTp.toDouble / (4 + 0 + 2 + 1 + 2 + 2)
val microRecallClass = sumTp.toDouble / (4 + 1 + 2 + 1 + 2 + 2)
val microF1MeasureClass = 2.0 * sumTp.toDouble /
(2 * sumTp.toDouble + (1 + 1 + 2) + (0 + 1 + 2))

val macroPrecisionDoc = 1.0 / 7 * (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0)
val macroRecallDoc = 1.0 / 7 * (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2)
val macroPrecisionDoc = 1.0 / 7 *
(1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0)
val macroRecallDoc = 1.0 / 7 *
(1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2)
val macroF1MeasureDoc = (1.0 / 7) *
2 * ( 1.0 / (2 + 2) + 1.0 / (2 + 2) + 0 + 1.0 / (1 + 1) +
2.0 / (2 + 2) + 2.0 / (3 + 2) + 1.0 / (1 + 2) )

val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1)

val strictAccuracy = 2.0 / 7
val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 /3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2)
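
// Numeric values of the expected results above, for reference:
//   microPrecisionClass = 8 / 11  ≈ 0.7273
//   microRecallClass    = 8 / 12  ≈ 0.6667
//   microF1MeasureClass = 16 / 23 ≈ 0.6957
//   hammingLoss         = 7 / 21  ≈ 0.3333
//   strictAccuracy      = 2 / 7   ≈ 0.2857
//   accuracy            = 23 / 42 ≈ 0.5476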

println("Ev" + metrics.macroPrecisionDoc)
println(macroPrecisionDoc)
println("Ev" + metrics.macroRecallDoc)
println(macroRecallDoc)
assert(math.abs(metrics.precisionClass(0.0) - precision0) < delta)
assert(math.abs(metrics.precisionClass(1.0) - precision1) < delta)
assert(math.abs(metrics.precisionClass(2.0) - precision2) < delta)
@@ -74,6 +83,11 @@ class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {

assert(math.abs(metrics.macroPrecisionDoc - macroPrecisionDoc) < delta)
assert(math.abs(metrics.macroRecallDoc - macroRecallDoc) < delta)
assert(math.abs(metrics.macroF1MeasureDoc - macroF1MeasureDoc) < delta)

assert(math.abs(metrics.hammingLoss - hammingLoss) < delta)
assert(math.abs(metrics.strictAccuracy - strictAccuracy) < delta)
assert(math.abs(metrics.accuracy - accuracy) < delta)


}