Skip to content

Commit

Permalink
Merge branch 'optimisation' of github.com:CeON/sparkling-ferns into optimisation
Browse files Browse the repository at this point in history
  • Loading branch information
pdendek committed Sep 22, 2015
2 parents c6d9cb0 + 193edc0 commit bd087c4
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 42 deletions.
10 changes: 4 additions & 6 deletions src/main/scala/pl/edu/icm/sparkling_ferns/FernBuilder.scala
@@ -1,22 +1,20 @@
package pl.edu.icm.sparkling_ferns

import breeze.numerics._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

/**
* @author Mateusz Fedoryszak (m.fedoryszak@icm.edu.pl)
*/
class FernBuilder(featureIndices: List[Int], binarisers: List[FeatureBinariser]) extends Serializable {
def toIndex(featureVector: Vector): Int = {
case class FernBuilder(featureIndices: List[Int], thresholds: Map[Int, Double], categoricalFeaturesInfo: Map[Int, Int]) extends Serializable {
val binarisers = Fern.sampleBinarisersPresetThresholds(thresholds, featureIndices, categoricalFeaturesInfo)

/**
 * Computes this fern's combination index for a single observation: extracts the
 * raw feature array, selects only the features at this fern's `featureIndices`,
 * and delegates to `Fern.toPointIndex` together with the fern's `binarisers`.
 *
 * NOTE(review): assumes every index in `featureIndices` is within the bounds of
 * `featureVector` — confirm this is guaranteed by the sampling in Fern.
 * `Fern.toPointIndex` is presumably packing the binarised features into a single
 * Int in [0, 2^numFeatures) — verify against its definition.
 *
 * @param featureVector the full (dense) feature vector of one data point
 * @return the fern-table index for this observation
 */
def toCombinationIndex(featureVector: Vector): Int = {
val features = featureVector.toArray
val selected = featureIndices.map(features)

Fern.toPointIndex(selected, binarisers)
}



def build(counts: Array[((Double, Int), Long)], labels: Array[Double]): FernModel = {
val numFeatures = featureIndices.length
val numDistinctPoints = 1 << numFeatures
Expand Down
73 changes: 38 additions & 35 deletions src/main/scala/pl/edu/icm/sparkling_ferns/FernForest.scala
Expand Up @@ -12,11 +12,18 @@ import scala.util.Random
/**
* @author Mateusz Fedoryszak (m.fedoryszak@icm.edu.pl)
*/
class FernForestModel(private val ferns: List[FernModel]) extends ClassificationModel with Serializable {
class FernForestModel(val ferns: Array[FernModel]) extends ClassificationModel with Serializable {
override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict)

override def predict(testData: Vector): Double = {
val scores = ferns.map(_.scores(testData))
predictSubset(testData, 0 until ferns.length)
}

/**
* Make a prediction using only a subset of ferns specified by the fernIdcs parameter.
*/
def predictSubset(testData: Vector, fernIdcs: TraversableOnce[Int]): Double = {
val scores = fernIdcs.map(ferns).map(_.scores(testData))
val scoreSums = scores.reduce(util.arrayReduction[Double](_ + _))
val labels = ferns.head.labels
val labelIdx = (0 until labels.length) maxBy scoreSums
Expand All @@ -32,19 +39,36 @@ case class FernForestModelWithStats(model: FernForestModel, oobConfusionMatrix:
*/
class FernForest {
def run(data: RDD[LabeledPoint], numFerns: Int, numFeatures: Int, categoricalFeaturesInfo: Map[Int, Int]): FernForestModel = {
runAndAssess(data, numFerns, numFeatures, categoricalFeaturesInfo).model
//val labels = util.extractLabels(data)
//new FernForestModel(List.fill(numFerns)(Fern.train(data, numFeatures, categoricalFeaturesInfo, labels)))
val withMultipliers = data.map(x => (x, Array.fill(numFerns)(Poisson.distribution(1.0).draw())))

runWithMultipliers(withMultipliers, numFerns, numFeatures, categoricalFeaturesInfo)
}

def runAndAssess(data: RDD[LabeledPoint], numFerns: Int, numFeatures: Int, categoricalFeaturesInfo: Map[Int, Int]): FernForestModelWithStats = {
val metadata = DatasetMetadata.fromData(data)

val withMultipliers = data.map(x => (x, Array.fill(numFerns)(Poisson.distribution(1.0).draw())))

val model = runWithMultipliers(withMultipliers, numFerns, numFeatures, categoricalFeaturesInfo)

val confusionMatrix = withMultipliers.filter(_._2.contains(0)).map{ case (point, muls) =>
val fernIndices = muls.toList.zipWithIndex.filter(_._1 == 0).map(_._2)
((point.label, model.predictSubset(point.features, fernIndices)), 1l)
}.reduceByKey(_ + _).collect().toList


//TODO: constant number of passes implementation
val featureImportance = model.ferns.zipWithIndex
.flatMap{ case (fern, i) => fern.featureImportance(withMultipliers.filter(_._2(i) == 0).map(_._1)) }
.groupBy(_._1).mapValues(_.unzip._2).mapValues(util.mean(_)).toList

FernForestModelWithStats(model, confusionMatrix, featureImportance)
}

def runWithMultipliers(withMultipliers: RDD[(LabeledPoint, Array[Int])], numFerns: Int, numFeatures: Int, categoricalFeaturesInfo: Map[Int, Int]): FernForestModel = {
val metadata = DatasetMetadata.fromData(withMultipliers.map(_._1))

val featureIndicesPerFern = Array.fill(numFerns)(Fern.sampleFeatureIndices(metadata.numFeatures, numFeatures))

val thresholdsPerFern = withMultipliers.flatMap { case(point, muls) =>
val thresholds = withMultipliers.flatMap { case(point, muls) =>
val features = point.features.toArray
for {
fernIdx <- 0 until numFerns
Expand All @@ -53,43 +77,25 @@ class FernForest {
} yield ((fernIdx, featureIdx), List((Random.nextFloat(), features(featureIdx))))
}.reduceByKey((list1, list2) => (list1 ++ list2).sortBy(_._1).take(2))
.mapValues(_.unzip._2).mapValues(list => list.sum / list.size).collect()
.groupBy(_._1._1).mapValues(_.map{case ((fernIdx, featureIdx), threshold) => (featureIdx, threshold)}.toMap)

val binarisersPerFern = Array.tabulate(numFerns)(i =>
Fern.sampleBinarisersPresetThresholds(
thresholdsPerFern.getOrElse(i, Map.empty), featureIndicesPerFern(i), categoricalFeaturesInfo))
val thresholdsPerFern = thresholds.groupBy(_._1._1).mapValues(
_.map{case ((fernIdx, featureIdx), threshold) => (featureIdx, threshold)}.toMap)

val fernBuilders = (0 until numFerns).map{i =>
new FernBuilder(featureIndicesPerFern(i), binarisersPerFern(i))
new FernBuilder(featureIndicesPerFern(i), thresholdsPerFern.getOrElse(i, Map.empty), categoricalFeaturesInfo)
}

val counts = withMultipliers.flatMap { case (point, muls) =>
(0 until numFerns).map { i =>
((i, point.label, fernBuilders(i).toIndex(point.features)), muls(i).toLong)
((i, point.label, fernBuilders(i).toCombinationIndex(point.features)), muls(i).toLong)
}
}.reduceByKey(_ + _).collect()

val countsPerFern = counts.groupBy(_._1._1).mapValues(_.map{ case ((_, label, idx), count) => (label, idx) -> count})

val ferns = (0 until numFerns).toList.map { i => fernBuilders(i).build(countsPerFern(i), metadata.labels)}

val model = new FernForestModel(ferns)

val confusionMatrix = withMultipliers.flatMap{ case (point, muls) =>
val fernIndices = muls.toList.zipWithIndex.filter(_._1 == 0).map(_._2)
fernIndices.map(ferns).map(fern => ((point.label, fern.predict(point.features)), 1l))
}.reduceByKey(_ + _).collect().toList

val modelsWithStats = List.fill(numFerns)(Fern.trainAndAssess(data, numFeatures, categoricalFeaturesInfo, metadata.labels))

val featureImportance = Nil //modelsWithStats.flatMap(_.featureImportance).groupBy(_._1).map{case (idx, list) => (idx, util.mean(list.unzip._2))}.toList

FernForestModelWithStats(model, confusionMatrix, featureImportance)
}
val ferns = (0 until numFerns).map{ i => fernBuilders(i).build(countsPerFern(i), metadata.labels)}.toArray

def run(data: RDD[LabeledPoint], featureIndices: List[List[Int]]): FernForestModel = {
val labels = util.extractLabels(data)
new FernForestModel(featureIndices.map(Fern.train(data, _, labels)))
new FernForestModel(ferns)
}
}

Expand All @@ -102,7 +108,4 @@ object FernForest {

def trainAndAssess(input: RDD[LabeledPoint], numFerns: Int, numFeatures: Int, categoricalFeaturesInfo: Map[Int, Int]): FernForestModelWithStats =
new FernForest().runAndAssess(input, numFerns, numFeatures, categoricalFeaturesInfo)

def train(input: RDD[LabeledPoint], featureIndices: List[List[Int]]): FernForestModel =
new FernForest().run(input, featureIndices)
}
2 changes: 1 addition & 1 deletion src/main/scala/pl/edu/icm/sparkling_ferns/util.scala
Expand Up @@ -12,7 +12,7 @@ object util {
Array.tabulate(minLen)(i => f(a1(i), a2(i)))
}

def mean[T](s: Seq[T])(implicit n: Fractional[T]) = n.div(s.sum, n.fromInt(s.size))
def mean[T](s: Traversable[T])(implicit n: Fractional[T]) = n.div(s.sum, n.fromInt(s.size))

def extractLabels(data: RDD[LabeledPoint]) =
data.map(p => p.label).distinct().collect()
Expand Down

0 comments on commit bd087c4

Please sign in to comment.