Commit

Run test 100 times.
yhuai committed Sep 17, 2015
1 parent 8bde803 commit b4bdef6
Showing 1 changed file with 61 additions and 59 deletions.
@@ -551,67 +551,69 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
     }
   }
 
-  test("udaf with all data types") {
-    val struct =
-      StructType(
-        StructField("f1", FloatType, true) ::
-        StructField("f2", ArrayType(BooleanType), true) :: Nil)
-    val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType,
-      ByteType, ShortType, IntegerType, LongType,
-      FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
-      DateType, TimestampType,
-      ArrayType(IntegerType), MapType(StringType, LongType), struct,
-      new MyDenseVectorUDT())
-    // Right now, we will use SortBasedAggregate to handle UDAFs.
-    // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use
-    // UnsafeRow as the aggregation buffer. While, dataTypes will trigger
-    // SortBasedAggregate to use a safe row as the aggregation buffer.
-    Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { dataTypes =>
-      val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
-        StructField(s"col$index", dataType, nullable = true)
-      }
-      // The schema used for data generator.
-      val schemaForGenerator = StructType(fields)
-      // The schema used for the DataFrame df.
-      val schema = StructType(StructField("id", IntegerType) +: fields)
-
-      logInfo(s"Testing schema: ${schema.treeString}")
-
-      val udaf = new ScalaAggregateFunction(schema)
-      // Generate data at the driver side. We need to materialize the data first and then
-      // create RDD.
-      val maybeDataGenerator =
-        RandomDataGenerator.forType(
-          dataType = schemaForGenerator,
-          nullable = true,
-          seed = Some(System.nanoTime()))
-      val dataGenerator =
-        maybeDataGenerator
-          .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator"))
-      val data = (1 to 50).map { i =>
-        dataGenerator.apply() match {
-          case row: Row => Row.fromSeq(i +: row.toSeq)
-          case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null))
-          case other =>
-            fail(s"Row or null is expected to be generated, " +
-              s"but a ${other.getClass.getCanonicalName} is generated.")
-        }
-      }
-
-      // Create a DF for the schema with random data.
-      val rdd = sqlContext.sparkContext.parallelize(data, 1)
-      val df = sqlContext.createDataFrame(rdd, schema)
-
-      val allColumns = df.schema.fields.map(f => col(f.name))
-      val expectedAnaswer =
-        data
-          .find(r => r.getInt(0) == 50)
-          .getOrElse(fail("A row with id 50 should be the expected answer."))
-      checkAnswer(
-        df.groupBy().agg(udaf(allColumns: _*)),
-        // udaf returns a Row as the output value.
-        Row(expectedAnaswer)
-      )
-    }
-  }
+  (1 to 100).foreach { i =>
+    test(s"udaf with all data types: run $i") {
+      val struct =
+        StructType(
+          StructField("f1", FloatType, true) ::
+          StructField("f2", ArrayType(BooleanType), true) :: Nil)
+      val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType,
+        ByteType, ShortType, IntegerType, LongType,
+        FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
+        DateType, TimestampType,
+        ArrayType(IntegerType), MapType(StringType, LongType), struct,
+        new MyDenseVectorUDT())
+      // Right now, we will use SortBasedAggregate to handle UDAFs.
+      // UnsafeRow.mutableFieldTypes.asScala.toSeq will trigger SortBasedAggregate to use
+      // UnsafeRow as the aggregation buffer. While, dataTypes will trigger
+      // SortBasedAggregate to use a safe row as the aggregation buffer.
+      Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { dataTypes =>
+        val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
+          StructField(s"col$index", dataType, nullable = true)
+        }
+        // The schema used for data generator.
+        val schemaForGenerator = StructType(fields)
+        // The schema used for the DataFrame df.
+        val schema = StructType(StructField("id", IntegerType) +: fields)
+
+        logInfo(s"Testing schema: ${schema.treeString}")
+
+        val udaf = new ScalaAggregateFunction(schema)
+        // Generate data at the driver side. We need to materialize the data first and then
+        // create RDD.
+        val maybeDataGenerator =
+          RandomDataGenerator.forType(
+            dataType = schemaForGenerator,
+            nullable = true,
+            seed = Some(System.nanoTime()))
+        val dataGenerator =
+          maybeDataGenerator
+            .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator"))
+        val data = (1 to 50).map { i =>
+          dataGenerator.apply() match {
+            case row: Row => Row.fromSeq(i +: row.toSeq)
+            case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null))
+            case other =>
+              fail(s"Row or null is expected to be generated, " +
+                s"but a ${other.getClass.getCanonicalName} is generated.")
+          }
+        }
+
+        // Create a DF for the schema with random data.
+        val rdd = sqlContext.sparkContext.parallelize(data, 1)
+        val df = sqlContext.createDataFrame(rdd, schema)
+
+        val allColumns = df.schema.fields.map(f => col(f.name))
+        val expectedAnaswer =
+          data
+            .find(r => r.getInt(0) == 50)
+            .getOrElse(fail("A row with id 50 should be the expected answer."))
+        checkAnswer(
+          df.groupBy().agg(udaf(allColumns: _*)),
+          // udaf returns a Row as the output value.
+          Row(expectedAnaswer)
+        )
+      }
+    }
+  }
 
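
Note on the pattern: ScalaTest's test(...) registers a test when the suite is constructed, so wrapping the body in (1 to 100).foreach { i => ... } registers 100 independently named, independently reported runs. This is a common way to flush out a flaky, seed-dependent failure, since each run draws a fresh seed from System.nanoTime(). Below is a minimal self-contained sketch of the same technique; the suite name, test name, and assertion are illustrative stand-ins, not part of this commit, and it assumes plain org.scalatest.FunSuite, which Spark's test suites build on.

import org.scalatest.FunSuite

// Hypothetical suite showing the loop-registration pattern from the diff above:
// each iteration registers a separately named test at suite-construction time.
class RepeatedRunSuite extends FunSuite {
  (1 to 100).foreach { i =>
    test(s"seed-dependent check: run $i") {
      // Fresh seed per run, mirroring seed = Some(System.nanoTime()) above;
      // reporting it via info() makes a failing run reproducible.
      val seed = System.nanoTime()
      info(s"run $i uses seed $seed")
      val rng = new scala.util.Random(seed)
      val data = Seq.fill(50)(rng.nextInt())
      // Stand-in assertion; the real suite compares a UDAF result via checkAnswer.
      assert(data.length == 50)
    }
  }
}

Interpolating $i into the test name is required, not cosmetic: ScalaTest rejects duplicate test names within a suite, and distinct names let a single failing iteration be re-run in isolation (e.g. with the runner's -z "run 37" substring filter).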
