Skip to content

Commit

Permalink
Make testTranformer per row special case of
Browse files Browse the repository at this point in the history
testTransformerByGlobalCheckFunc.
  • Loading branch information
MrBago committed Dec 28, 2017
1 parent 7bc588a commit de345dc
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 45 deletions.
26 changes: 9 additions & 17 deletions mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
transformer: Transformer,
firstResultCol: String,
otherResultCols: String*)
(checkFunction: Row => Unit)
(globalCheckFunction: Seq[Row] => Unit): Unit = {

val columnNames = dataframe.schema.fieldNames
Expand All @@ -71,7 +70,7 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
.select(firstResultCol, otherResultCols: _*)
testStream(streamOutput) (
AddData(stream, data: _*),
CheckAnswer(checkFunction, globalCheckFunction)
CheckAnswer(globalCheckFunction)
)
}

Expand All @@ -80,18 +79,10 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
transformer: Transformer,
firstResultCol: String,
otherResultCols: String*)
(checkFunction: Row => Unit)
(globalCheckFunction: Seq[Row] => Unit): Unit = {
val dfOutput = transformer.transform(dataframe)
val outputs = dfOutput.select(firstResultCol, otherResultCols: _*).collect()
if (checkFunction != null) {
outputs.foreach { row =>
checkFunction(row)
}
}
if (globalCheckFunction != null) {
globalCheckFunction(outputs)
}
globalCheckFunction(outputs)
}

def testTransformer[A : Encoder](
Expand All @@ -100,10 +91,11 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
firstResultCol: String,
otherResultCols: String*)
(checkFunction: Row => Unit): Unit = {
testTransformerOnStreamData(dataframe, transformer, firstResultCol,
otherResultCols: _*)(checkFunction)(null)
testTransformerOnDF(dataframe, transformer, firstResultCol,
otherResultCols: _*)(checkFunction)(null)
testTransformerByGlobalCheckFunc(
dataframe,
transformer,
firstResultCol,
otherResultCols: _*) { rows: Seq[Row] => rows.foreach(checkFunction(_)) }
}

def testTransformerByGlobalCheckFunc[A : Encoder](
Expand All @@ -113,8 +105,8 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
otherResultCols: String*)
(globalCheckFunction: Seq[Row] => Unit): Unit = {
testTransformerOnStreamData(dataframe, transformer, firstResultCol,
otherResultCols: _*)(null)(globalCheckFunction)
otherResultCols: _*)(globalCheckFunction)
testTransformerOnDF(dataframe, transformer, firstResultCol,
otherResultCols: _*)(null)(globalCheckFunction)
otherResultCols: _*)(globalCheckFunction)
}
}
10 changes: 5 additions & 5 deletions mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.ml.util

import org.apache.spark.ml.{PipelineModel, Transformer}
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.Row

Expand All @@ -32,21 +31,22 @@ class MLTestSuite extends MLTest {
val indexer = new StringIndexer().setStringOrderType("alphabetAsc")
.setInputCol("label").setOutputCol("indexed")
val indexerModel = indexer.fit(data)
testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
testTransformer[(Int, String)](data, indexerModel, "id", "indexed") {
case Row(id: Int, indexed: Double) =>
assert(id === indexed.toInt)
} { rows: Seq[Row] =>
}
testTransformerByGlobalCheckFunc[(Int, String)] (data, indexerModel, "id", "indexed") { rows =>
assert(rows.map(_.getDouble(1)).max === 5.0)
}

intercept[Exception] {
testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
case Row(id: Int, indexed: Double) =>
assert(id != indexed.toInt)
} (null)
}
}
intercept[Exception] {
testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") (null) {
testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", "indexed") {
rows: Seq[Row] =>
assert(rows.map(_.getDouble(1)).max === 1.0)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be

def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, false)

def apply(checkFunction: Row => Unit,
globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
CheckAnswerRowsByFunc(checkFunction, globalCheckFunction, false)
def apply(globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
CheckAnswerRowsByFunc(globalCheckFunction, false)
}

/**
Expand All @@ -162,9 +161,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be

def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false)

def apply(checkFunction: Row => Unit,
globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
CheckAnswerRowsByFunc(checkFunction, globalCheckFunction, true)
def apply(globalCheckFunction: Seq[Row] => Unit): CheckAnswerRowsByFunc =
CheckAnswerRowsByFunc(globalCheckFunction, true)
}

case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, isSorted: Boolean)
Expand All @@ -180,7 +178,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
}

case class CheckAnswerRowsByFunc(
checkFunction: Row => Unit,
globalCheckFunction: Seq[Row] => Unit,
lastOnly: Boolean) extends StreamAction with StreamMustBeRunning {
override def toString: String = s"$operatorName"
Expand Down Expand Up @@ -643,23 +640,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
error => failTest(error)
}

case CheckAnswerRowsByFunc(checkFunction, globalCheckFunction, lastOnly) =>
case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) =>
val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
if (checkFunction != null) {
sparkAnswer.foreach { row =>
try {
checkFunction(row)
} catch {
case e: Throwable => failTest(e.toString)
}
}
}
if (globalCheckFunction != null) {
try {
globalCheckFunction(sparkAnswer)
} catch {
case e: Throwable => failTest(e.toString)
}
try {
globalCheckFunction(sparkAnswer)
} catch {
case e: Throwable => failTest(e.toString)
}
}
pos += 1
Expand Down

0 comments on commit de345dc

Please sign in to comment.