Skip to content

Commit

Permalink
rewrote Pipeline as a pseudo-case class to force pipeline nodes to be…
Browse files Browse the repository at this point in the history
… rewritten upon construction
  • Loading branch information
tomerk committed Jun 14, 2015
1 parent fbe7b5b commit 24b0a0e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 15 deletions.
5 changes: 1 addition & 4 deletions src/main/scala/workflow/Node.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ object Node {
* An implicit conversion to turn any [[Node]] into a [[Pipeline]], to allow chaining, fitting, etc.
*/
implicit def nodeToPipeline[A, B : ClassTag](node: Node[A, B]): Pipeline[A, B] = {
node match {
case Pipeline(nodes) => Pipeline[A, B](nodes)
case _ => Pipeline[A, B](node.rewrite)
}
Pipeline(node.rewrite)
}
}
45 changes: 34 additions & 11 deletions src/main/scala/workflow/Pipeline.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import scala.reflect.ClassTag
/**
* Created by tomerk11 on 5/28/15.
*/
case class Pipeline[A, B : ClassTag] private[workflow] (nodes: Seq[Node[_, _]]) extends Node[A, B] {
final class Pipeline[A, B : ClassTag] private (val nodes: Seq[Node[_, _]]) extends Node[A, B] {

/**
* Chains an estimator onto this Pipeline, producing a new estimator that when fit on same input type as
Expand Down Expand Up @@ -35,7 +35,7 @@ case class Pipeline[A, B : ClassTag] private[workflow] (nodes: Seq[Node[_, _]])
* @return A new pipeline made up of this one and the new Transformer
*/
def then[C : ClassTag](next: Node[B, C]): Pipeline[A, C] = {
Pipeline(nodes.flatMap(_.rewrite) ++ next.rewrite)
new Pipeline(nodes ++ next.rewrite)
}

/**
Expand All @@ -52,22 +52,28 @@ case class Pipeline[A, B : ClassTag] private[workflow] (nodes: Seq[Node[_, _]])
def fit(): PipelineModel[A, B] = PipelineModel(Pipeline.fit(nodes))

override def rewrite: Seq[Node[_, _]] = {
val rewrittenNodes = nodes.flatMap(_.rewrite) match {
case Pipeline(`nodes`) +: tail => `nodes` ++ tail
case rewritten => rewritten
}

if (rewrittenNodes.forall(_.canSafelyPrependExtraNodes)) {
rewrittenNodes
if (nodes.forall(_.canSafelyPrependExtraNodes)) {
nodes
} else {
Seq(Pipeline(nodes.flatMap(_.rewrite)))
Seq(this)
}
}

/**
 * A Pipeline always reports `true` here. `rewrite` consults this flag on every
 * contained node to decide whether the pipeline may be returned as its flat node
 * sequence or must stay wrapped as a single node.
 */
def canSafelyPrependExtraNodes: Boolean = true

/** Renders this pipeline as `Pipeline(...)` around the string form of its node sequence. */
override def toString: String = s"Pipeline($nodes)"

/**
 * Two pipelines are equal exactly when their node sequences are equal.
 *
 * The type arguments of [[Pipeline]] are erased at runtime, so the pattern uses
 * wildcards: matching on `Pipeline[A, B]` would only produce an unchecked warning
 * and could never actually test `A` or `B`.
 *
 * @param other the value to compare against
 * @return `true` if `other` is a Pipeline with an equal `nodes` sequence
 */
override def equals(other: Any): Boolean = other match {
  case that: Pipeline[_, _] => nodes == that.nodes
  case _ => false
}

/**
 * Hash code derived only from `nodes`, which keeps it consistent with `equals`.
 * A null node sequence contributes 0.
 */
override def hashCode(): Int = {
  val nodesHash = Option(nodes).map(_.##).getOrElse(0)
  41 + nodesHash
}
}

object Pipeline {
object Pipeline extends Serializable {
def fit(pipeline: Seq[Node[_, _]], prefix: Seq[Transformer[_, _]] = Seq.empty): Seq[Transformer[_, _]] = {
var transformers: Seq[Transformer[_, _]] = Seq()
pipeline.flatMap(_.rewrite).foreach { node =>
Expand All @@ -89,4 +95,21 @@ object Pipeline {

transformers
}

/**
 * Builds a [[Pipeline]] from the given nodes, rewriting every node on construction.
 *
 * Each node is expanded through its `rewrite` method and the results are
 * concatenated. If the first rewritten element is a Pipeline whose nodes are equal
 * to the `nodes` argument, that element is replaced by `nodes` followed by the
 * remaining rewritten elements. Otherwise the rewritten sequence is used as is.
 *
 * NOTE(review): the backticked `nodes` in the pattern is a stable-identifier
 * pattern, so it compares against the parameter rather than binding a new name.
 * Confirm this equality check is the intended condition.
 *
 * @param nodes the nodes to rewrite and wrap
 * @return a new Pipeline over the rewritten nodes
 */
def apply[A, B : ClassTag](nodes: Seq[Node[_, _]]): Pipeline[A, B] = {
  val expanded = nodes.flatMap(node => node.rewrite)

  val flattened = expanded match {
    case Pipeline(`nodes`) +: rest => nodes ++ rest
    case unchanged => unchanged
  }

  new Pipeline(flattened)
}

/**
 * Extractor that lets a [[Pipeline]] be matched like a case class.
 *
 * @param pipeline the pipeline to deconstruct; may be null
 * @return `None` for a null pipeline, otherwise `Some` of its node sequence
 */
def unapply[A, B](pipeline: Pipeline[A, B]): Option[Seq[Node[_, _]]] =
  Option(pipeline).map(_.nodes)
}

0 comments on commit 24b0a0e

Please sign in to comment.