From 9efa4adfa501bf6aacecae01698cf6580c2be8de Mon Sep 17 00:00:00 2001 From: Mateusz Fedoryszak Date: Mon, 14 Sep 2015 17:05:17 +0200 Subject: [PATCH] initial optimisation: reduce number of passes --- .../pl/edu/icm/sparkling_ferns/Fern.scala | 61 +++++++++++-------- .../edu/icm/sparkling_ferns/FernBuilder.scala | 26 ++++++++ .../edu/icm/sparkling_ferns/FernForest.scala | 45 ++++++++++++-- 3 files changed, 99 insertions(+), 33 deletions(-) create mode 100644 src/main/scala/pl/edu/icm/sparkling_ferns/FernBuilder.scala diff --git a/src/main/scala/pl/edu/icm/sparkling_ferns/Fern.scala b/src/main/scala/pl/edu/icm/sparkling_ferns/Fern.scala index 4f6f586..4746321 100644 --- a/src/main/scala/pl/edu/icm/sparkling_ferns/Fern.scala +++ b/src/main/scala/pl/edu/icm/sparkling_ferns/Fern.scala @@ -112,31 +112,8 @@ class Fern(val presetLabels: Option[Array[Double]] = None) { } def computeScores(training: RDD[(Double, Int)], numDistinctPoints: Int, labels: Array[Double]) = { - val aggregated = training.groupBy(identity).map(x => (x._1, x._2.size)).collect() - - val labelsRev = labels.toList.zipWithIndex.toMap - val numLabels = labels.length - - val objectsInLeafPerLabel = Array.fill[Long](numLabels, numDistinctPoints)(1) - val objectsInLeaf = Array.fill[Long](numDistinctPoints)(0) - val objectsPerLabel = Array.fill[Long](numLabels)(0) - - aggregated.foreach{ case ((label, pointIdx), count) => - val labelIdx = labelsRev(label) - objectsInLeafPerLabel(labelIdx)(pointIdx) += count - objectsInLeaf(pointIdx) += count - objectsPerLabel(labelIdx) += count - } - - val numSamples = objectsPerLabel.sum - - val scores = Array.tabulate[Double](numLabels, numDistinctPoints) { case (label, pointIdx) => log( - (objectsInLeafPerLabel(label)(pointIdx) + 1).toDouble/(objectsInLeaf(pointIdx) + numLabels) - * - (numSamples + numLabels).toDouble/(objectsPerLabel(label) + 1) - )} - - scores + val aggregated = training.groupBy(identity).map(x => (x._1, x._2.size.toLong)).collect() + Fern.computeScores(aggregated, numDistinctPoints, labels) } } @@ -200,8 +177,38 @@ object Fern { } def sampleFeatureIndices(data: RDD[LabeledPoint], numFeatures: Int): List[Int] = { - val allFeaturesNo = data.first().features.size - Random.shuffle((0 until allFeaturesNo).toList).take(numFeatures).sorted + val numFeaturesInData = data.first().features.size + sampleFeatureIndices(numFeaturesInData, numFeatures) + } + + def sampleFeatureIndices(numFeaturesInData: Int, numFeatures: Int): List[Int] = { + Random.shuffle((0 until numFeaturesInData).toList).take(numFeatures).sorted + } + + def computeScores(aggregated: Array[((Double, Int), Long)], numDistinctPoints: Int, labels: Array[Double]) = { + val labelsRev = labels.toList.zipWithIndex.toMap + val numLabels = labels.length + + val objectsInLeafPerLabel = Array.fill[Long](numLabels, numDistinctPoints)(1) + val objectsInLeaf = Array.fill[Long](numDistinctPoints)(0) + val objectsPerLabel = Array.fill[Long](numLabels)(0) + + aggregated.foreach{ case ((label, pointIdx), count) => + val labelIdx = labelsRev(label) + objectsInLeafPerLabel(labelIdx)(pointIdx) += count + objectsInLeaf(pointIdx) += count + objectsPerLabel(labelIdx) += count + } + + val numSamples = objectsPerLabel.sum + + val scores = Array.tabulate[Double](numLabels, numDistinctPoints) { case (label, pointIdx) => log( + (objectsInLeafPerLabel(label)(pointIdx) + 1).toDouble/(objectsInLeaf(pointIdx) + numLabels) + * + (numSamples + numLabels).toDouble/(objectsPerLabel(label) + 1) + )} + + scores } def shuffleFeatureValues(data: RDD[LabeledPoint], featureIndex: Int): RDD[(LabeledPoint, LabeledPoint)] = { diff --git a/src/main/scala/pl/edu/icm/sparkling_ferns/FernBuilder.scala b/src/main/scala/pl/edu/icm/sparkling_ferns/FernBuilder.scala new file mode 100644 index 0000000..e1a86a2 --- /dev/null +++ b/src/main/scala/pl/edu/icm/sparkling_ferns/FernBuilder.scala @@ -0,0 +1,26 @@ +package pl.edu.icm.sparkling_ferns + +import breeze.numerics._ +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * @author Mateusz Fedoryszak (m.fedoryszak@icm.edu.pl) + */ +class FernBuilder(featureIndices: List[Int], binarisers: List[FeatureBinariser]) extends Serializable { + def toIndex(featureVector: Vector): Int = { + val features = featureVector.toArray + val selected = featureIndices.map(features) + + Fern.toPointIndex(selected, binarisers) + } + + + + def build(counts: Array[((Double, Int), Long)], labels: Array[Double]): FernModel = { + val numFeatures = featureIndices.length + val numDistinctPoints = 1 << numFeatures + val scores = Fern.computeScores(counts, numDistinctPoints, labels) + new FernModel(labels, featureIndices, binarisers, scores) + } +} diff --git a/src/main/scala/pl/edu/icm/sparkling_ferns/FernForest.scala b/src/main/scala/pl/edu/icm/sparkling_ferns/FernForest.scala index aedb26d..7091a5a 100644 --- a/src/main/scala/pl/edu/icm/sparkling_ferns/FernForest.scala +++ b/src/main/scala/pl/edu/icm/sparkling_ferns/FernForest.scala @@ -1,5 +1,7 @@ package pl.edu.icm.sparkling_ferns +import breeze.stats.distributions.Poisson +import org.apache.spark.SparkContext._ import org.apache.spark.mllib.classification.ClassificationModel import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -28,18 +30,49 @@ case class FernForestModelWithStats(model: FernForestModel, oobConfusionMatrix: */ class FernForest { def run(data: RDD[LabeledPoint], numFerns: Int, numFeatures: Int, categoricalFeaturesInfo: Map[Int, Int]): FernForestModel = { - val labels = util.extractLabels(data) - new FernForestModel(List.fill(numFerns)(Fern.train(data, numFeatures, categoricalFeaturesInfo, labels))) + runAndAssess(data, numFerns, numFeatures, categoricalFeaturesInfo).model + //val labels = util.extractLabels(data) + //new FernForestModel(List.fill(numFerns)(Fern.train(data, numFeatures, categoricalFeaturesInfo, labels))) } def runAndAssess(data: RDD[LabeledPoint], numFerns: Int, numFeatures: Int, categoricalFeaturesInfo: Map[Int, Int]): FernForestModelWithStats = { val labels = util.extractLabels(data) - val modelsWithStats = List.fill(numFerns)(Fern.trainAndAssess(data, numFeatures, categoricalFeaturesInfo, labels)) - val featureImportance = modelsWithStats.flatMap(_.featureImportance).groupBy(_._1).map{case (idx, list) => (idx, util.mean(list.unzip._2))}.toList - val confusionMatrix = modelsWithStats.flatMap(_.oobConfusionMatrix).groupBy(_._1).map{case (cell, list) => (cell, list.unzip._2.sum)}.toList + val numFeaturesInData = data.take(1).head.features.size + + val withMultipliers = data.map(x => (x, Array.fill(numFerns)(Poisson.distribution(1.0).draw()))) + + val featureIndicesPerFern = Array.fill(numFerns)(Fern.sampleFeatureIndices(numFeaturesInData, numFeatures)) + + val binarisersPerFern = Array.tabulate(numFerns)(i => + Fern.sampleBinarisers( + withMultipliers.flatMap{case (point, muls) => List.fill(muls(i))(point)}, + featureIndicesPerFern(i), categoricalFeaturesInfo)) + + val fernBuilders = (0 until numFerns).map{i => + new FernBuilder(featureIndicesPerFern(i), binarisersPerFern(i)) + } + + val counts = withMultipliers.flatMap { case (point, muls) => + (0 until numFerns).map { i => + ((i, point.label, fernBuilders(i).toIndex(point.features)), muls(i).toLong) + } + }.reduceByKey(_ + _).collect() + + val countsPerFern = counts.groupBy(_._1._1).mapValues(_.map{ case ((_, label, idx), count) => (label, idx) -> count}) + + val ferns = (0 until numFerns).toList.map { i => fernBuilders(i).build(countsPerFern(i), labels)} + + val model = new FernForestModel(ferns) + + val confusionMatrix = withMultipliers.flatMap{ case (point, muls) => + val fernIndices = muls.toList.zipWithIndex.filter(_._1 == 0).map(_._2) + fernIndices.map(ferns).map(fern => ((point.label, fern.predict(point.features)), 1l)) + }.reduceByKey(_ + _).collect().toList + + val modelsWithStats = List.fill(numFerns)(Fern.trainAndAssess(data, numFeatures, categoricalFeaturesInfo, labels)) - val model = new FernForestModel(modelsWithStats.map(_.model)) + val featureImportance = Nil //modelsWithStats.flatMap(_.featureImportance).groupBy(_._1).map{case (idx, list) => (idx, util.mean(list.unzip._2))}.toList FernForestModelWithStats(model, confusionMatrix, featureImportance) }