// FittedPipeline.scala
package workflow.graph
import org.apache.spark.rdd.RDD
/**
 * This is the result of fitting a [[Pipeline]]. It is logically equivalent to the Pipeline it is produced by,
 * but with all Estimators pre-fit, and only containing Transformers in the underlying graph.
 * Applying a FittedPipeline to new data does not trigger any new optimization or estimator fitting.
 *
 * Unlike normal Pipelines, FittedPipelines are serializable and may be written to and from disk.
 *
 * @param transformerGraph The DAG representing the execution (only contains Transformers)
 * @param source The SourceId of the Pipeline
 * @param sink The SinkId of the Pipeline
 * @tparam A type of the data this FittedPipeline expects as input
 * @tparam B type of the data this FittedPipeline outputs
 */
class FittedPipeline[A, B] private[graph] (
    private[graph] val transformerGraph: TransformerGraph,
    private[graph] val source: SourceId,
    private[graph] val sink: SinkId
  ) extends Chainable[A, B] with Serializable {

  /**
   * Converts this FittedPipeline back into a Pipeline.
   *
   * Cached as a lazy val so that repeated `apply` calls reuse the same executor
   * instead of rebuilding the graph on every application. Marked `@transient` so
   * serialization still only ships the (serializable) transformer graph; the
   * pipeline is transparently rebuilt on first use after deserialization.
   * `optimize = false` because a fitted pipeline must never re-trigger optimization.
   */
  @transient override private[graph] lazy val toPipeline: Pipeline[A, B] = new Pipeline(
    new GraphExecutor(transformerGraph.toGraph, optimize = false),
    source,
    sink)

  /**
   * The application of this FittedPipeline to a single input item.
   *
   * @param in The input item to pass into this transformer
   * @return The output value
   */
  def apply(in: A): B = toPipeline.apply(in).get()

  /**
   * The application of this FittedPipeline to an RDD of input items.
   *
   * @param in The RDD input to pass into this transformer
   * @return The RDD output for the given input
   */
  def apply(in: RDD[A]): RDD[B] = toPipeline.apply(in).get()
}