Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
229 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package workflow | ||
|
||
import org.apache.spark.rdd.RDD | ||
|
||
import scala.reflect.ClassTag | ||
|
||
/**
 * A [[TransformerNode]] that gathers the outputs of several upstream data
 * dependencies into a single `Seq[T]` per record.
 */
private[workflow] class GatherTransformer[T] extends TransformerNode[Seq[T]] {
  /** Single-item path: each dependency contributes one element of the result Seq. */
  def transform(dataDependencies: Seq[_], fitDependencies: Seq[TransformerNode[_]]): Seq[T] = {
    dataDependencies.map(dep => dep.asInstanceOf[T])
  }

  /**
   * RDD path: wraps each dependency's records in singleton Seqs, then zips the
   * RDDs pairwise and concatenates, so record i of the output is the Seq of the
   * i-th record from every dependency.
   * NOTE(review): relies on all dependency RDDs being identically partitioned
   * and equally sized (a requirement of RDD.zip) — confirm callers guarantee this.
   */
  def transformRDD(dataDependencies: Seq[RDD[_]], fitDependencies: Seq[TransformerNode[_]]): RDD[Seq[T]] = {
    val singletons = dataDependencies.map(dep => dep.asInstanceOf[RDD[T]].map(record => Seq(record)))
    singletons.reduceLeft { (left, right) =>
      left.zip(right).map { case (l, r) => l ++ r }
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
package workflow | ||
|
||
import org.apache.spark.SparkContext | ||
import org.scalatest.FunSuite | ||
import pipelines.{Logging, LocalSparkContext} | ||
|
||
/** Tests for [[DelegatingTransformer]]: it should forward both single-item and
  * RDD transforms to the transformer supplied as a fit dependency. */
class DelegatingTransformerSuite extends FunSuite with LocalSparkContext with Logging {
  test("single apply") {
    val hasher = Transformer[String, Int](s => s.hashCode)
    val delegate = new DelegatingTransformer[Int]("label")

    val input = "A31DFSsafds*be31"
    // Delegation on a single item must match calling the underlying transformer directly.
    assert(delegate.transform(Seq(input), Seq(hasher)) === input.hashCode)
  }

  test("rdd apply") {
    sc = new SparkContext("local", "test")

    val hasher = Transformer[String, Int](s => s.hashCode)
    val delegate = new DelegatingTransformer[Int]("label")

    val inputs = Seq("A31DFSsafds*be31", "lj32fsd", "woadsf8923")
    val hashed = delegate.transformRDD(Seq(sc.parallelize(inputs)), Seq(hasher)).collect()
    // RDD delegation must hash every record, preserving order.
    assert(hashed.toSeq === inputs.map(s => s.hashCode))
  }

}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
package workflow | ||
|
||
import org.apache.spark.SparkContext | ||
import org.apache.spark.rdd.RDD | ||
import org.scalatest.FunSuite | ||
import pipelines.{LocalSparkContext, Logging} | ||
|
||
/** Tests for [[Estimator]]: `withData` should fit on the training data and
  * produce a pipeline applying the fitted transformer to new data. */
class EstimatorSuite extends FunSuite with LocalSparkContext with Logging {
  test("estimator withData") {
    sc = new SparkContext("local", "test")

    // Fits a transformer that ignores its input and emits the first training value.
    val intEstimator = new Estimator[Int, Int] {
      protected def fit(data: RDD[Int]): Transformer[Int, Int] = {
        val head = data.first()
        Transformer(_ => head)
      }
    }

    val trainData = sc.parallelize(Seq(32, 94, 12))
    val testData = sc.parallelize(Seq(42, 58, 61))

    val pipeline = intEstimator.withData(trainData)
    // Every test record maps to the first training record (32).
    assert(pipeline.apply(testData).collect().toSeq === Seq.fill(3)(32))
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package workflow | ||
|
||
import org.apache.spark.SparkContext | ||
import org.apache.spark.rdd.RDD | ||
import org.scalatest.FunSuite | ||
import pipelines.{LocalSparkContext, Logging} | ||
|
||
/** Tests for [[LabelEstimator]]: `withData` should fit on training data plus
  * labels and produce a pipeline applying the fitted transformer. */
class LabelEstimatorSuite extends FunSuite with LocalSparkContext with Logging {
  test("estimator withData") {
    sc = new SparkContext("local", "test")

    // Fits a constant transformer built from the first datum and first label's hash.
    val intEstimator = new LabelEstimator[Int, Int, String] {
      protected def fit(data: RDD[Int], labels: RDD[String]): Transformer[Int, Int] = {
        val head = data.first()
        val labelHash = labels.first().hashCode
        Transformer(_ => head + labelHash)
      }
    }

    val trainData = sc.parallelize(Seq(32, 94, 12))
    val trainLabels = sc.parallelize(Seq("sjkfdl", "iw", "432"))
    val testData = sc.parallelize(Seq(42, 58, 61))

    val pipeline = intEstimator.withData(trainData, trainLabels)
    // Every test record maps to first-datum + hash of first label.
    assert(pipeline.apply(testData).collect().toSeq === Seq.fill(3)(32 + "sjkfdl".hashCode))
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
package workflow | ||
|
||
import org.apache.spark.SparkContext | ||
import org.apache.spark.rdd.RDD | ||
import org.scalatest.FunSuite | ||
import pipelines.{LocalSparkContext, Logging} | ||
|
||
/** Tests for [[Pipeline]] composition: transformer chaining, estimator and
  * label-estimator chaining via `andThen`, and `Pipeline.gather`. */
class PipelineSuite extends FunSuite with LocalSparkContext with Logging {
  test("pipeline chaining") {
    sc = new SparkContext("local", "test")

    val times2 = Transformer[Int, Int](x => x * 2)
    val minus3 = Transformer[Int, Int](x => x - 3)

    val data = sc.parallelize(Seq(32, 94, 12))
    val pipeline = times2 andThen minus3

    val pipelineOut = pipeline(data).collect().toSeq

    // Chained pipeline applies the stages left-to-right on both items and RDDs.
    assert(pipeline(7) === (7 * 2) - 3)
    assert(pipelineOut === Seq(32, 94, 12).map(x => (x * 2) - 3))
  }

  test("estimator chaining") {
    sc = new SparkContext("local", "test")

    val doubler = Transformer[Int, Int](x => x * 2)

    // Fits "add the first training value" on whatever data reaches it.
    val intEstimator = new Estimator[Int, Int] {
      protected def fit(data: RDD[Int]): Transformer[Int, Int] = {
        val head = data.first()
        Transformer(x => x + head)
      }
    }

    val data = sc.parallelize(Seq(32, 94, 12))
    val pipeline = doubler andThen (intEstimator, data)

    val pipelineOut = pipeline(data).collect().toSeq
    val pipelineLastTransformerOut = pipeline.fittedTransformer(data).collect().toSeq

    // The estimator is fit on doubled data, so its constant is 32*2.
    assert(pipelineOut === Seq(32, 94, 12).map(x => x * 2 + 32 * 2))
    // fittedTransformer applies only the fitted stage, skipping the doubler.
    assert(pipelineLastTransformerOut === Seq(32, 94, 12).map(x => x + 32 * 2))
  }

  test("label estimator chaining") {
    sc = new SparkContext("local", "test")

    val doubler = Transformer[Int, Int](x => x * 2)

    // Fits "add first-datum + first-label" using both data and labels.
    val intEstimator = new LabelEstimator[Int, Int, String] {
      protected def fit(data: RDD[Int], labels: RDD[String]): Transformer[Int, Int] = {
        val head = data.first() + labels.first().toInt
        Transformer(x => x + head)
      }
    }

    val data = sc.parallelize(Seq(32, 94, 12))
    val labels = sc.parallelize(Seq("10", "7", "14"))
    val pipeline = doubler andThen (intEstimator, data, labels)

    val pipelineOut = pipeline(data).collect().toSeq
    val pipelineLastTransformerOut = pipeline.fittedTransformer(data).collect().toSeq

    // Fit on doubled data with label "10": constant is 32*2 + 10.
    assert(pipelineOut === Seq(32, 94, 12).map(x => x * 2 + 32 * 2 + 10))
    assert(pipelineLastTransformerOut === Seq(32, 94, 12).map(x => x + 32 * 2 + 10))
  }

  test("Pipeline gather") {
    sc = new SparkContext("local", "test")

    val pipeA = Transformer[Int, Int](x => x * 2) andThen Transformer[Int, Int](x => x - 3)

    val pipeB = Transformer[Int, Int](x => x * 2) andThen (new Estimator[Int, Int] {
      protected def fit(data: RDD[Int]): Transformer[Int, Int] = {
        val head = data.first()
        Transformer(x => x + head)
      }
    }, sc.parallelize(Seq(32, 94, 12)))

    val pipeC = Transformer[Int, Int](x => x * 4) andThen (new LabelEstimator[Int, Int, String] {
      protected def fit(data: RDD[Int], labels: RDD[String]): Transformer[Int, Int] = {
        val head = data.first() + labels.first().toInt
        Transformer(x => x + head)
      }
    }, sc.parallelize(Seq(32, 94, 12)), sc.parallelize(Seq("10", "7", "14")))

    val pipeline = Pipeline.gather(List(pipeA, pipeB, pipeC))

    // Gather should emit, per record, the Seq of all branch outputs in order.
    val single = 7
    assert(pipeline(single) === List(pipeA, pipeB, pipeC).map(p => p.apply(single)))

    val data = Seq(13, 2, 83)
    val expected = data.map(x => List(pipeA, pipeB, pipeC).map(p => p.apply(x)))
    assert(pipeline(sc.parallelize(data)).collect().toSeq === expected)
  }
}