Commit

Added first rewrite rule
tomerk committed Jun 14, 2015
1 parent 8a4c31a commit b6c4425
Showing 3 changed files with 39 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/main/resources/log4j.properties
@@ -6,5 +6,6 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}:

# Only pay attention to INFO messages from Keystone.
log4j.logger.pipelines=INFO
+log4j.logger.workflow=INFO
log4j.logger.nodes=INFO
log4j.logger.utils=INFO
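Note: the added logger line follows the file's existing per-package pattern, presumably so that INFO-level messages from classes in the new workflow package become visible. A minimal, hypothetical sketch of the kind of class that line would govern, assuming the project's pipelines.Logging trait names its log4j logger after the enclosing class (so anything under workflow falls under log4j.logger.workflow):

package workflow

import pipelines.Logging

// Hypothetical illustration only (not part of this commit): an object in the
// workflow package whose INFO messages are enabled by log4j.logger.workflow=INFO.
object OptimizationReporter extends Logging {
  def report(message: String): Unit = logInfo(message)
}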
5 changes: 3 additions & 2 deletions src/main/scala/pipelines/text/NewsgroupsPipeline.scala
@@ -9,6 +9,7 @@ import nodes.util.{CommonSparseFeatures, MaxClassifier}
import org.apache.spark.{SparkConf, SparkContext}
import pipelines.Logging
import scopt.OptionParser
+import workflow.Optimizer

object NewsgroupsPipeline extends Logging {
val appName = "NewsgroupsPipeline"
@@ -33,8 +34,8 @@ object NewsgroupsPipeline extends Logging {
.withData(trainData.data, trainData.labels)
.andThen(MaxClassifier)

logInfo("\n" + predictorPipeline.toDOTString)
val predictor = predictorPipeline
val predictor = Optimizer.execute(predictorPipeline)
logInfo("\n" + predictor.toDOTString)

// Evaluate the classifier
logInfo("Evaluating classifier")
35 changes: 35 additions & 0 deletions src/main/scala/workflow/Optimizer.scala
@@ -0,0 +1,35 @@
package workflow

/**
 * Optimizes a Pipeline DAG
 */
object Optimizer extends RuleExecutor {
  protected val batches: Seq[Batch] = Batch("DAG Optimization", FixedPoint(100), EquivalentNodeMerger) :: Nil
}

/**
 * Merges equivalent node indices in the DAG.
 * Nodes are considered equivalent if:
 * - The nodes at the indices themselves are equal
 * - They point to the same data dependencies
 * - They point to the same fit dependencies
 */
object EquivalentNodeMerger extends Rule {
  def apply[A, B](plan: Pipeline[A, B]): Pipeline[A, B] = {
    // Pair each node with its (data, fit) dependencies, keeping its index in the DAG.
    val fullNodes = plan.nodes.zip(plan.dataDeps.zip(plan.fitDeps)).zipWithIndex
    // Group identical (node, dependencies) entries; each group holds the old indices that can be merged.
    val mergableNodes = fullNodes.groupBy(_._1).mapValues(_.map(_._2)).toSeq

    if (mergableNodes.size == plan.nodes.size) {
      // no nodes are mergable
      plan
    } else {
      // Map every old index to its group's position in the deduplicated node list,
      // keeping the special SOURCE index pointing at itself.
      val oldToNewNodeMapping = mergableNodes.zipWithIndex.map(x => x._1._2.map(y => (y, x._2))).flatMap(identity).toMap + (Pipeline.SOURCE -> Pipeline.SOURCE)
      val newNodes = mergableNodes.map(_._1._1)
      val newDataDeps = mergableNodes.map(_._1._2._1.map(x => oldToNewNodeMapping(x)))
      val newFitDeps = mergableNodes.map(_._1._2._2.map(x => oldToNewNodeMapping(x)))
      val newSink = oldToNewNodeMapping(plan.sink)

      Pipeline(newNodes, newDataDeps, newFitDeps, newSink)
    }
  }
}

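Note: the RuleExecutor, Batch, FixedPoint, and Rule types that Optimizer builds on are referenced but not included in this commit. The following is only a rough sketch of what they might look like, inferred from how they are used above (a Catalyst-style executor that reapplies each batch of rules until the plan stops changing or the iteration cap is hit); the actual definitions in the repository may differ:

package workflow

// Sketch inferred from usage in this commit, not taken from the repository.

// How many times a batch of rules may be reapplied.
sealed trait Strategy { def maxIterations: Int }
case class FixedPoint(maxIterations: Int) extends Strategy

// A rewrite rule maps a pipeline DAG to an equivalent DAG.
abstract class Rule {
  def apply[A, B](plan: Pipeline[A, B]): Pipeline[A, B]
}

// A named group of rules applied under one strategy.
case class Batch(name: String, strategy: Strategy, rules: Rule*)

// Runs every batch, reapplying its rules until the plan stops changing
// or the strategy's iteration limit is reached.
abstract class RuleExecutor {
  protected val batches: Seq[Batch]

  def execute[A, B](plan: Pipeline[A, B]): Pipeline[A, B] = {
    var current = plan
    for (batch <- batches) {
      var iteration = 0
      var changed = true
      while (changed && iteration < batch.strategy.maxIterations) {
        val next = batch.rules.foldLeft(current)((p, rule) => rule(p))
        changed = next != current
        current = next
        iteration += 1
      }
    }
    current
  }
}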
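Note: to make the merging concrete, here is a self-contained toy example that runs the same zip/groupBy/remap steps on plain Scala collections, with strings standing in for pipeline nodes and -1 standing in for Pipeline.SOURCE (whose actual value is not shown in this commit):

object EquivalentNodeMergerExample extends App {
  // Two branches of a toy DAG start with identical "Tokenizer" nodes (indices 0 and 2)
  // that both read from the source, so those two indices are mergeable.
  val nodes    = Seq("Tokenizer", "TermFrequency", "Tokenizer", "NGrams")
  val dataDeps = Seq(Seq(-1), Seq(0), Seq(-1), Seq(2))
  val fitDeps  = Seq.fill(4)(Seq.empty[Int])

  // Same shape as the rule above: (node, (dataDeps, fitDeps)) paired with its old index.
  val fullNodes = nodes.zip(dataDeps.zip(fitDeps)).zipWithIndex
  val mergableNodes = fullNodes.groupBy(_._1).mapValues(_.map(_._2)).toSeq

  // Old index -> new index; the two Tokenizer indices map to the same new position.
  val oldToNewNodeMapping = mergableNodes.zipWithIndex.flatMap {
    case ((_, oldIndices), newIndex) => oldIndices.map(_ -> newIndex)
  }.toMap + (-1 -> -1)

  val newNodes    = mergableNodes.map(_._1._1)
  val newDataDeps = mergableNodes.map(_._1._2._1.map(oldToNewNodeMapping))

  println(newNodes)            // three nodes remain instead of four
  println(oldToNewNodeMapping) // indices 0 and 2 share one target index
  println(newDataDeps)         // dependencies rewritten to the merged indices
}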