Skip to content

Commit

Permalink
initial optimisation: reduce number of passes
Browse files Browse the repository at this point in the history
  • Loading branch information
matfed committed Sep 14, 2015
1 parent 520293a commit 9efa4ad
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 33 deletions.
61 changes: 34 additions & 27 deletions src/main/scala/pl/edu/icm/sparkling_ferns/Fern.scala
Expand Up @@ -112,31 +112,8 @@ class Fern(val presetLabels: Option[Array[Double]] = None) {
}

/**
 * Counts (label, leafIndex) occurrences across the whole training set in a single
 * distributed pass and delegates the score computation to [[Fern.computeScores]].
 *
 * @param training          RDD of (label, leaf index) pairs, one per training sample.
 * @param numDistinctPoints number of leaves of the fern (2^numFeatures).
 * @param labels            the distinct class labels, fixing the row order of the score matrix.
 * @return per-label, per-leaf log-odds score matrix.
 */
def computeScores(training: RDD[(Double, Int)], numDistinctPoints: Int, labels: Array[Double]) = {
  // countByValue aggregates map-side instead of shuffling every duplicate pair
  // the way groupBy(identity) would; the result is small (labels x leaves).
  val aggregated = training.countByValue().toArray
  Fern.computeScores(aggregated, numDistinctPoints, labels)
}
}

Expand Down Expand Up @@ -200,8 +177,38 @@ object Fern {
}

/**
 * Samples `numFeatures` distinct feature indices for a fern, reading only the
 * dimensionality of the data (from its first record).
 *
 * @param data        training data; only `data.first().features.size` is consulted.
 * @param numFeatures number of feature indices to draw.
 * @return sorted list of distinct feature indices.
 */
def sampleFeatureIndices(data: RDD[LabeledPoint], numFeatures: Int): List[Int] = {
  val numFeaturesInData = data.first().features.size
  sampleFeatureIndices(numFeaturesInData, numFeatures)
}

/**
 * Draws `numFeatures` distinct indices uniformly at random from
 * [0, numFeaturesInData) and returns them in ascending order.
 */
def sampleFeatureIndices(numFeaturesInData: Int, numFeatures: Int): List[Int] = {
  val allIndices = List.range(0, numFeaturesInData)
  val drawn = Random.shuffle(allIndices).take(numFeatures)
  drawn.sorted
}

/**
 * Turns aggregated (label, leaf) counts into a per-label, per-leaf matrix of
 * Laplace-smoothed log-odds scores.
 *
 * @param aggregated        ((label, leafIndex), count) tuples, one per observed pair.
 * @param numDistinctPoints number of leaves of the fern.
 * @param labels            distinct class labels; their order fixes the matrix rows.
 * @return score matrix indexed as scores(labelIndex)(leafIndex).
 */
def computeScores(aggregated: Array[((Double, Int), Long)], numDistinctPoints: Int, labels: Array[Double]) = {
  val labelIndex = labels.toList.zipWithIndex.toMap
  val numLabels = labels.length

  // Tallies; the per-(label, leaf) cells start at 1 (additive smoothing).
  val countPerLabelAndLeaf = Array.fill[Long](numLabels, numDistinctPoints)(1)
  val countPerLeaf = Array.fill[Long](numDistinctPoints)(0)
  val countPerLabel = Array.fill[Long](numLabels)(0)

  for (((label, leafIdx), count) <- aggregated) {
    val li = labelIndex(label)
    countPerLabelAndLeaf(li)(leafIdx) += count
    countPerLeaf(leafIdx) += count
    countPerLabel(li) += count
  }

  val totalSamples = countPerLabel.sum

  // Smoothed log of P(leaf | label) / P(leaf): +1 / +numLabels keep every
  // cell finite even when a (label, leaf) pair was never observed.
  Array.tabulate[Double](numLabels, numDistinctPoints) { (li, leafIdx) =>
    log(
      (countPerLabelAndLeaf(li)(leafIdx) + 1).toDouble / (countPerLeaf(leafIdx) + numLabels) *
        (totalSamples + numLabels).toDouble / (countPerLabel(li) + 1)
    )
  }
}

def shuffleFeatureValues(data: RDD[LabeledPoint], featureIndex: Int): RDD[(LabeledPoint, LabeledPoint)] = {
Expand Down
26 changes: 26 additions & 0 deletions src/main/scala/pl/edu/icm/sparkling_ferns/FernBuilder.scala
@@ -0,0 +1,26 @@
package pl.edu.icm.sparkling_ferns

import breeze.numerics._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

/**
* @author Mateusz Fedoryszak (m.fedoryszak@icm.edu.pl)
*/
/**
 * Serialisable recipe for a single fern: a fixed set of feature indices plus the
 * binarisers that turn those features into a leaf index. Lets the leaf-index
 * computation run distributed while the model is assembled from collected counts.
 *
 * @author Mateusz Fedoryszak (m.fedoryszak@icm.edu.pl)
 */
class FernBuilder(featureIndices: List[Int], binarisers: List[FeatureBinariser]) extends Serializable {
  /** Maps a full feature vector to the fern's leaf index for that sample. */
  def toIndex(featureVector: Vector): Int = {
    val allFeatures = featureVector.toArray
    val selectedFeatures = featureIndices.map(allFeatures)
    Fern.toPointIndex(selectedFeatures, binarisers)
  }

  /**
   * Assembles the trained fern model from aggregated ((label, leafIndex), count)
   * statistics.
   */
  def build(counts: Array[((Double, Int), Long)], labels: Array[Double]): FernModel = {
    val numDistinctPoints = 1 << featureIndices.length
    new FernModel(labels, featureIndices, binarisers, Fern.computeScores(counts, numDistinctPoints, labels))
  }
}
45 changes: 39 additions & 6 deletions src/main/scala/pl/edu/icm/sparkling_ferns/FernForest.scala
@@ -1,5 +1,7 @@
package pl.edu.icm.sparkling_ferns

import breeze.stats.distributions.Poisson
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.classification.ClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
Expand Down Expand Up @@ -28,18 +30,49 @@ case class FernForestModelWithStats(model: FernForestModel, oobConfusionMatrix:
*/
class FernForest {
/**
 * Trains a fern forest. Thin wrapper over [[runAndAssess]] that discards the
 * out-of-bag statistics and returns only the model.
 *
 * @param data                    training data.
 * @param numFerns                number of ferns in the forest.
 * @param numFeatures             features sampled per fern.
 * @param categoricalFeaturesInfo arity of each categorical feature, keyed by index.
 */
def run(data: RDD[LabeledPoint], numFerns: Int, numFeatures: Int, categoricalFeaturesInfo: Map[Int, Int]): FernForestModel =
  runAndAssess(data, numFerns, numFeatures, categoricalFeaturesInfo).model

/**
 * Trains a fern forest and computes out-of-bag statistics in a fixed, small
 * number of passes over the data (the point of this optimisation: counts for
 * all ferns are gathered in one distributed pass instead of one per fern).
 *
 * @param data                    training data.
 * @param numFerns                number of ferns in the forest.
 * @param numFeatures             features sampled per fern.
 * @param categoricalFeaturesInfo arity of each categorical feature, keyed by index.
 * @return the model together with its OOB confusion matrix and feature importance.
 */
def runAndAssess(data: RDD[LabeledPoint], numFerns: Int, numFeatures: Int, categoricalFeaturesInfo: Map[Int, Int]): FernForestModelWithStats = {
  val labels = util.extractLabels(data)
  val numFeaturesInData = data.take(1).head.features.size

  // Poisson(1) multipliers emulate a bootstrap sample for every fern at once.
  // MUST be cached: the draws are random, so without caching each downstream
  // action would recompute the RDD and see a *different* bootstrap sample,
  // making binarisers, counts and the OOB matrix mutually inconsistent.
  val withMultipliers = data.map(x => (x, Array.fill(numFerns)(Poisson.distribution(1.0).draw()))).cache()

  val featureIndicesPerFern = Array.fill(numFerns)(Fern.sampleFeatureIndices(numFeaturesInData, numFeatures))

  val binarisersPerFern = Array.tabulate(numFerns)(i =>
    Fern.sampleBinarisers(
      withMultipliers.flatMap { case (point, muls) => List.fill(muls(i))(point) },
      featureIndicesPerFern(i), categoricalFeaturesInfo))

  val fernBuilders = (0 until numFerns).map { i =>
    new FernBuilder(featureIndicesPerFern(i), binarisersPerFern(i))
  }

  // One pass for all ferns: (fern, label, leaf) counts weighted by the
  // bootstrap multipliers.
  val counts = withMultipliers.flatMap { case (point, muls) =>
    (0 until numFerns).map { i =>
      ((i, point.label, fernBuilders(i).toIndex(point.features)), muls(i).toLong)
    }
  }.reduceByKey(_ + _).collect()

  val countsPerFern = counts.groupBy(_._1._1).mapValues(_.map { case ((_, label, idx), count) => (label, idx) -> count })

  val ferns = (0 until numFerns).toList.map { i => fernBuilders(i).build(countsPerFern(i), labels) }

  val model = new FernForestModel(ferns)

  // OOB confusion matrix: each point is scored only by the ferns whose
  // bootstrap multiplier for it was zero (i.e. ferns it did not train).
  val confusionMatrix = withMultipliers.flatMap { case (point, muls) =>
    val oobFernIndices = muls.toList.zipWithIndex.filter(_._1 == 0).map(_._2)
    oobFernIndices.map(ferns).map(fern => ((point.label, fern.predict(point.features)), 1L))
  }.reduceByKey(_ + _).collect().toList

  withMultipliers.unpersist()

  // TODO(review): feature importance is not computed in this optimised path yet.
  val featureImportance = Nil

  FernForestModelWithStats(model, confusionMatrix, featureImportance)
}
Expand Down

0 comments on commit 9efa4ad

Please sign in to comment.