Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
matfed committed Sep 15, 2015
1 parent bd26cee commit 193edc0
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 40 deletions.
10 changes: 4 additions & 6 deletions src/main/scala/pl/edu/icm/sparkling_ferns/FernBuilder.scala
@@ -1,22 +1,20 @@
package pl.edu.icm.sparkling_ferns

import breeze.numerics._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

/**
* @author Mateusz Fedoryszak (m.fedoryszak@icm.edu.pl)
*/
class FernBuilder(featureIndices: List[Int], binarisers: List[FeatureBinariser]) extends Serializable {
def toIndex(featureVector: Vector): Int = {
case class FernBuilder(featureIndices: List[Int], thresholds: Map[Int, Double], categoricalFeaturesInfo: Map[Int, Int]) extends Serializable {
val binarisers = Fern.sampleBinarisersPresetThresholds(thresholds, featureIndices, categoricalFeaturesInfo)

/**
 * Maps a full feature vector to this fern's combination index: picks out the
 * features at `featureIndices`, then delegates to `Fern.toPointIndex` with the
 * fern's binarisers to obtain the integer bucket index.
 */
def toCombinationIndex(featureVector: Vector): Int = {
  val allFeatures = featureVector.toArray
  val fernFeatures = featureIndices.map(allFeatures)
  Fern.toPointIndex(fernFeatures, binarisers)
}



def build(counts: Array[((Double, Int), Long)], labels: Array[Double]): FernModel = {
val numFeatures = featureIndices.length
val numDistinctPoints = 1 << numFeatures
Expand Down
62 changes: 29 additions & 33 deletions src/main/scala/pl/edu/icm/sparkling_ferns/FernForest.scala
Expand Up @@ -12,7 +12,7 @@ import scala.util.Random
/**
* @author Mateusz Fedoryszak (m.fedoryszak@icm.edu.pl)
*/
class FernForestModel(private val ferns: Array[FernModel]) extends ClassificationModel with Serializable {
class FernForestModel(val ferns: Array[FernModel]) extends ClassificationModel with Serializable {
/** Predicts a label for every feature vector in `testData` by applying the single-vector `predict` overload to each element. */
override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict)

override def predict(testData: Vector): Double = {
Expand All @@ -39,19 +39,36 @@ case class FernForestModelWithStats(model: FernForestModel, oobConfusionMatrix:
*/
class FernForest {
def run(data: RDD[LabeledPoint], numFerns: Int, numFeatures: Int, categoricalFeaturesInfo: Map[Int, Int]): FernForestModel = {
runAndAssess(data, numFerns, numFeatures, categoricalFeaturesInfo).model
//val labels = util.extractLabels(data)
//new FernForestModel(List.fill(numFerns)(Fern.train(data, numFeatures, categoricalFeaturesInfo, labels)))
val withMultipliers = data.map(x => (x, Array.fill(numFerns)(Poisson.distribution(1.0).draw())))

runWithMultipliers(withMultipliers, numFerns, numFeatures, categoricalFeaturesInfo)
}

/**
 * Trains a fern forest and additionally computes out-of-bag (OOB) statistics.
 *
 * Each point is paired with `numFerns` independent Poisson(1) draws, used as
 * per-fern replication multipliers (Poisson bootstrap). A multiplier of 0 for
 * fern i means the point contributed weight 0 to that fern's training counts,
 * so it can be used as an OOB sample for fern i.
 */
def runAndAssess(data: RDD[LabeledPoint], numFerns: Int, numFeatures: Int, categoricalFeaturesInfo: Map[Int, Int]): FernForestModelWithStats = {
// NOTE(review): `metadata` is not referenced anywhere below in this method
// (runWithMultipliers derives its own metadata) — looks like leftover code; verify.
val metadata = DatasetMetadata.fromData(data)

// One Poisson(1) multiplier per fern per data point (bootstrap resampling weights).
val withMultipliers = data.map(x => (x, Array.fill(numFerns)(Poisson.distribution(1.0).draw())))

val model = runWithMultipliers(withMultipliers, numFerns, numFeatures, categoricalFeaturesInfo)

// OOB confusion matrix: for each point that is OOB for at least one fern,
// predict using only the ferns that did NOT see it, and count
// (trueLabel, predictedLabel) pairs. NOTE: `1l` uses lowercase 'l' — easily
// misread as '1'; prefer `1L`.
val confusionMatrix = withMultipliers.filter(_._2.contains(0)).map{ case (point, muls) =>
val fernIndices = muls.toList.zipWithIndex.filter(_._1 == 0).map(_._2)
((point.label, model.predictSubset(point.features, fernIndices)), 1l)
}.reduceByKey(_ + _).collect().toList


//TODO: constant number of passes implementation
// Per-feature importance: each fern scores its features on its own OOB subset;
// scores for the same feature index across ferns are averaged.
val featureImportance = model.ferns.zipWithIndex
.flatMap{ case (fern, i) => fern.featureImportance(withMultipliers.filter(_._2(i) == 0).map(_._1)) }
.groupBy(_._1).mapValues(_.unzip._2).mapValues(util.mean(_)).toList

FernForestModelWithStats(model, confusionMatrix, featureImportance)
}

def runWithMultipliers(withMultipliers: RDD[(LabeledPoint, Array[Int])], numFerns: Int, numFeatures: Int, categoricalFeaturesInfo: Map[Int, Int]): FernForestModel = {
val metadata = DatasetMetadata.fromData(withMultipliers.map(_._1))

val featureIndicesPerFern = Array.fill(numFerns)(Fern.sampleFeatureIndices(metadata.numFeatures, numFeatures))

val thresholdsPerFern = withMultipliers.flatMap { case(point, muls) =>
val thresholds = withMultipliers.flatMap { case(point, muls) =>
val features = point.features.toArray
for {
fernIdx <- 0 until numFerns
Expand All @@ -60,43 +77,25 @@ class FernForest {
} yield ((fernIdx, featureIdx), List((Random.nextFloat(), features(featureIdx))))
}.reduceByKey((list1, list2) => (list1 ++ list2).sortBy(_._1).take(2))
.mapValues(_.unzip._2).mapValues(list => list.sum / list.size).collect()
.groupBy(_._1._1).mapValues(_.map{case ((fernIdx, featureIdx), threshold) => (featureIdx, threshold)}.toMap)

val binarisersPerFern = Array.tabulate(numFerns)(i =>
Fern.sampleBinarisersPresetThresholds(
thresholdsPerFern.getOrElse(i, Map.empty), featureIndicesPerFern(i), categoricalFeaturesInfo))
val thresholdsPerFern = thresholds.groupBy(_._1._1).mapValues(
_.map{case ((fernIdx, featureIdx), threshold) => (featureIdx, threshold)}.toMap)

val fernBuilders = (0 until numFerns).map{i =>
new FernBuilder(featureIndicesPerFern(i), binarisersPerFern(i))
new FernBuilder(featureIndicesPerFern(i), thresholdsPerFern.getOrElse(i, Map.empty), categoricalFeaturesInfo)
}

val counts = withMultipliers.flatMap { case (point, muls) =>
(0 until numFerns).map { i =>
((i, point.label, fernBuilders(i).toIndex(point.features)), muls(i).toLong)
((i, point.label, fernBuilders(i).toCombinationIndex(point.features)), muls(i).toLong)
}
}.reduceByKey(_ + _).collect()

val countsPerFern = counts.groupBy(_._1._1).mapValues(_.map{ case ((_, label, idx), count) => (label, idx) -> count})

val ferns = (0 until numFerns).map{ i => fernBuilders(i).build(countsPerFern(i), metadata.labels)}.toArray

val model = new FernForestModel(ferns)

val confusionMatrix = withMultipliers.filter(_._2.contains(0)).map{ case (point, muls) =>
val fernIndices = muls.toList.zipWithIndex.filter(_._1 == 0).map(_._2)
((point.label, model.predictSubset(point.features, fernIndices)), 1l)
}.reduceByKey(_ + _).collect().toList

val modelsWithStats = List.fill(numFerns)(Fern.trainAndAssess(data, numFeatures, categoricalFeaturesInfo, metadata.labels))

val featureImportance = Nil //modelsWithStats.flatMap(_.featureImportance).groupBy(_._1).map{case (idx, list) => (idx, util.mean(list.unzip._2))}.toList

FernForestModelWithStats(model, confusionMatrix, featureImportance)
}

def run(data: RDD[LabeledPoint], featureIndices: List[List[Int]]): FernForestModel = {
val labels = util.extractLabels(data)
new FernForestModel(featureIndices.map(Fern.train(data, _, labels)).toArray)
new FernForestModel(ferns)
}
}

Expand All @@ -109,7 +108,4 @@ object FernForest {

/**
 * Convenience entry point: trains a fern forest on `input` and returns the model
 * together with out-of-bag statistics. Delegates to `FernForest.runAndAssess`.
 *
 * @param input training data as labeled points
 * @param numFerns number of ferns in the forest
 * @param numFeatures number of features sampled per fern
 * @param categoricalFeaturesInfo map from categorical feature index to its arity
 */
def trainAndAssess(input: RDD[LabeledPoint], numFerns: Int, numFeatures: Int, categoricalFeaturesInfo: Map[Int, Int]): FernForestModelWithStats =
new FernForest().runAndAssess(input, numFerns, numFeatures, categoricalFeaturesInfo)

def train(input: RDD[LabeledPoint], featureIndices: List[List[Int]]): FernForestModel =
new FernForest().run(input, featureIndices)
}
2 changes: 1 addition & 1 deletion src/main/scala/pl/edu/icm/sparkling_ferns/util.scala
Expand Up @@ -12,7 +12,7 @@ object util {
Array.tabulate(minLen)(i => f(a1(i), a2(i)))
}

def mean[T](s: Seq[T])(implicit n: Fractional[T]) = n.div(s.sum, n.fromInt(s.size))
/** Arithmetic mean of the values in `s` (division by zero semantics of `Fractional` apply when `s` is empty). */
def mean[T](s: Traversable[T])(implicit n: Fractional[T]) = {
  val total = s.sum
  n.div(total, n.fromInt(s.size))
}

/** Collects the distinct label values occurring in `data` to the driver as an array. */
def extractLabels(data: RDD[LabeledPoint]) =
  data.map(_.label).distinct().collect()
Expand Down

0 comments on commit 193edc0

Please sign in to comment.