diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index 4db7266abc8e2..ccaea18ecab2a 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -82,7 +82,7 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { saveAsObjectFile(rdd2, fileName2) rdd <- objectFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1, recursive = TRUE) unlink(fileName2, recursive = TRUE) diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index a1e354e567be5..3be8c65a6c1a0 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -38,13 +38,13 @@ test_that("union on two RDDs", { union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, c(as.list(nums), mockFile)) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") rdd<- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") unlink(fileName) }) diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R index 8bc693be20c3c..844d86f3cc97f 100644 --- a/R/pkg/inst/tests/test_includeJAR.R +++ b/R/pkg/inst/tests/test_includeJAR.R @@ -31,7 +31,7 @@ runScript <- function() { test_that("sparkJars tag in SparkContext", { testOutput <- runScript() helloTest <- testOutput[1] - expect_true(helloTest == "Hello, Dave") + expect_equal(helloTest, "Hello, Dave") basicFunction <- testOutput[2] - expect_true(basicFunction == 4L) + expect_equal(basicFunction, "4") }) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R index fff028657db37..2552127cc547f 100644 --- a/R/pkg/inst/tests/test_parallelize_collect.R +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -57,7 +57,7 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { strListRDD2) for (rdd in rdds) { - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(.hasSlot(rdd, "jrdd") && inherits(rdd@jrdd, "jobj") && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 4fe653856756e..fc3c01d837de4 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -33,9 +33,9 @@ test_that("get number of partitions in RDD", { }) test_that("first on RDD", { - expect_true(first(rdd) == 1) + expect_equal(first(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) - expect_true(first(newrdd) == 2) + expect_equal(first(newrdd), 2) }) test_that("count and length on RDD", { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 6a08f894313c4..0e4235ea8b4b3 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -61,7 +61,7 @@ test_that("infer types", { expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = "integer", containsNull = TRUE)) testStruct <- infer_type(list(a = 1L, b = "2")) - expect_true(class(testStruct) == "structType") + expect_equal(class(testStruct), "structType") checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() @@ -73,39 +73,39 @@ test_that("infer types", { test_that("structType and structField", { testField <- structField("a", "string") - expect_true(inherits(testField, "structField")) - expect_true(testField$name() == "a") + expect_is(testField, "structField") + expect_equal(testField$name(), "a") expect_true(testField$nullable()) testSchema <- structType(testField, structField("b", "integer")) - expect_true(inherits(testSchema, "structType")) - expect_true(inherits(testSchema$fields()[[2]], "structField")) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType") + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") }) test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(sqlContext, rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- createDataFrame(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- createDataFrame(sqlContext, rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- createDataFrame(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) @@ -150,26 +150,26 @@ test_that("convert NAs to null type in DataFrames", { test_that("toDF", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- toDF(rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) @@ -219,21 +219,21 @@ test_that("create DataFrame with different data types", { test_that("jsonFile() on a local file returns a DataFrame", { df <- jsonFile(sqlContext, jsonPath) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) }) test_that("jsonRDD() on a RDD with json string", { rdd <- parallelize(sc, mockLines) - expect_true(count(rdd) == 3) + expect_equal(count(rdd), 3) df <- jsonRDD(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) rdd2 <- flatMap(rdd, function(x) c(x, x)) df <- jsonRDD(sqlContext, rdd2) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 6) + expect_is(df, "DataFrame") + expect_equal(count(df), 6) }) test_that("test cache, uncache and clearCache", { @@ -248,9 +248,9 @@ test_that("test cache, uncache and clearCache", { test_that("test tableNames and tables", { df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - expect_true(length(tableNames(sqlContext)) == 1) + expect_equal(length(tableNames(sqlContext)), 1) df <- tables(sqlContext) - expect_true(count(df) == 1) + expect_equal(count(df), 1) dropTempTable(sqlContext, "table1") }) @@ -258,8 +258,8 @@ test_that("registerTempTable() results in a queryable table and sql() results in df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") - expect_true(inherits(newdf, "DataFrame")) - expect_true(count(newdf) == 1) + expect_is(newdf, "DataFrame") + expect_equal(count(newdf), 1) dropTempTable(sqlContext, "table1") }) @@ -279,14 +279,14 @@ test_that("insertInto() on a registered table", { registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1") - expect_true(count(sql(sqlContext, "select * from table1")) == 5) - expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Michael") + expect_equal(count(sql(sqlContext, "select * from table1")), 5) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") dropTempTable(sqlContext, "table1") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_true(count(sql(sqlContext, "select * from table1")) == 2) - expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Bob") + expect_equal(count(sql(sqlContext, "select * from table1")), 2) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") dropTempTable(sqlContext, "table1") }) @@ -294,16 +294,16 @@ test_that("table() returns a new DataFrame", { df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") tabledf <- table(sqlContext, "table1") - expect_true(inherits(tabledf, "DataFrame")) - expect_true(count(tabledf) == 3) + expect_is(tabledf, "DataFrame") + expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") }) test_that("toRDD() returns an RRDD", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toRDD(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(count(testRDD) == 3) + expect_is(testRDD, "RDD") + expect_equal(count(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { @@ -311,9 +311,9 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) - expect_true(inherits(unioned, "RDD")) - expect_true(SparkR:::getSerializedMode(unioned) == "byte") - expect_true(collect(unioned)[[2]]$name == "Andy") + expect_is(unioned, "RDD") + expect_equal(SparkR:::getSerializedMode(unioned), "byte") + expect_equal(collect(unioned)[[2]]$name, "Andy") }) test_that("union on mixed serialization types correctly returns a byte RRDD", { @@ -333,16 +333,16 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) - expect_true(inherits(unionByte, "RDD")) - expect_true(SparkR:::getSerializedMode(unionByte) == "byte") - expect_true(collect(unionByte)[[1]] == 1) - expect_true(collect(unionByte)[[12]]$name == "Andy") + expect_is(unionByte, "RDD") + expect_equal(SparkR:::getSerializedMode(unionByte), "byte") + expect_equal(collect(unionByte)[[1]], 1) + expect_equal(collect(unionByte)[[12]]$name, "Andy") unionString <- unionRDD(textRDD, dfRDD) - expect_true(inherits(unionString, "RDD")) - expect_true(SparkR:::getSerializedMode(unionString) == "byte") - expect_true(collect(unionString)[[1]] == "Michael") - expect_true(collect(unionString)[[5]]$name == "Andy") + expect_is(unionString, "RDD") + expect_equal(SparkR:::getSerializedMode(unionString), "byte") + expect_equal(collect(unionString)[[1]], "Michael") + expect_equal(collect(unionString)[[5]]$name, "Andy") }) test_that("objectFile() works with row serialization", { @@ -352,7 +352,7 @@ test_that("objectFile() works with row serialization", { saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) - expect_true(inherits(objectIn, "RDD")) + expect_is(objectIn, "RDD") expect_equal(SparkR:::getSerializedMode(objectIn), "byte") expect_equal(collect(objectIn)[[2]]$age, 30) }) @@ -363,32 +363,32 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", { row$newCol <- row$age + 5 row }) - expect_true(inherits(testRDD, "RDD")) + expect_is(testRDD, "RDD") collected <- collect(testRDD) - expect_true(collected[[1]]$name == "Michael") - expect_true(collected[[2]]$newCol == "35") + expect_equal(collected[[1]]$name, "Michael") + expect_equal(collected[[2]]$newCol, 35) }) test_that("collect() returns a data.frame", { df <- jsonFile(sqlContext, jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) - expect_true(names(rdf)[1] == "age") - expect_true(nrow(rdf) == 3) - expect_true(ncol(rdf) == 2) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 3) + expect_equal(ncol(rdf), 2) }) test_that("limit() returns DataFrame with the correct number of rows", { df <- jsonFile(sqlContext, jsonPath) dfLimited <- limit(df, 2) - expect_true(inherits(dfLimited, "DataFrame")) - expect_true(count(dfLimited) == 2) + expect_is(dfLimited, "DataFrame") + expect_equal(count(dfLimited), 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { df <- jsonFile(sqlContext, jsonPath) - expect_true(nrow(collect(df)) == nrow(take(df, 10))) - expect_true(ncol(collect(df)) == ncol(take(df, 10))) + expect_equal(nrow(collect(df)), nrow(take(df, 10))) + expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { @@ -401,9 +401,9 @@ test_that("multiple pipeline transformations starting with a DataFrame result in row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE row }) - expect_true(inherits(second, "RDD")) - expect_true(count(second) == 3) - expect_true(collect(second)[[2]]$age == 35) + expect_is(second, "RDD") + expect_equal(count(second), 3) + expect_equal(collect(second)[[2]]$age, 35) expect_true(collect(second)[[2]]$testCol) expect_false(collect(second)[[3]]$testCol) }) @@ -430,36 +430,36 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { test_that("schema(), dtypes(), columns(), names() return the correct values/format", { df <- jsonFile(sqlContext, jsonPath) testSchema <- schema(df) - expect_true(length(testSchema$fields()) == 2) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") - expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") - expect_true(testSchema$fields()[[1]]$name() == "age") + expect_equal(length(testSchema$fields()), 2) + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") + expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string") + expect_equal(testSchema$fields()[[1]]$name(), "age") testTypes <- dtypes(df) - expect_true(length(testTypes[[1]]) == 2) - expect_true(testTypes[[1]][1] == "age") + expect_equal(length(testTypes[[1]]), 2) + expect_equal(testTypes[[1]][1], "age") testCols <- columns(df) - expect_true(length(testCols) == 2) - expect_true(testCols[2] == "name") + expect_equal(length(testCols), 2) + expect_equal(testCols[2], "name") testNames <- names(df) - expect_true(length(testNames) == 2) - expect_true(testNames[2] == "name") + expect_equal(length(testNames), 2) + expect_equal(testNames[2], "name") }) test_that("head() and first() return the correct data", { df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) - expect_true(nrow(testHead) == 3) - expect_true(ncol(testHead) == 2) + expect_equal(nrow(testHead), 3) + expect_equal(ncol(testHead), 2) testHead2 <- head(df, 2) - expect_true(nrow(testHead2) == 2) - expect_true(ncol(testHead2) == 2) + expect_equal(nrow(testHead2), 2) + expect_equal(ncol(testHead2), 2) testFirst <- first(df) - expect_true(nrow(testFirst) == 1) + expect_equal(nrow(testFirst), 1) }) test_that("distinct() on DataFrames", { @@ -472,15 +472,15 @@ test_that("distinct() on DataFrames", { df <- jsonFile(sqlContext, jsonPathWithDup) uniques <- distinct(df) - expect_true(inherits(uniques, "DataFrame")) - expect_true(count(uniques) == 3) + expect_is(uniques, "DataFrame") + expect_equal(count(uniques), 3) }) test_that("sample on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) - expect_true(inherits(sampled, "DataFrame")) + expect_is(sampled, "DataFrame") sampled2 <- sample(df, FALSE, 0.1) expect_true(count(sampled2) < 3) @@ -491,15 +491,15 @@ test_that("sample on a DataFrame", { test_that("select operators", { df <- select(jsonFile(sqlContext, jsonPath), "name", "age") - expect_true(inherits(df$name, "Column")) - expect_true(inherits(df[[2]], "Column")) - expect_true(inherits(df[["age"]], "Column")) + expect_is(df$name, "Column") + expect_is(df[[2]], "Column") + expect_is(df[["age"]], "Column") - expect_true(inherits(df[,1], "DataFrame")) + expect_is(df[,1], "DataFrame") expect_equal(columns(df[,1]), c("name")) expect_equal(columns(df[,"age"]), c("age")) df2 <- df[,c("age", "name")] - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(columns(df2), c("age", "name")) df$age2 <- df$age @@ -518,50 +518,50 @@ test_that("select operators", { test_that("select with column", { df <- jsonFile(sqlContext, jsonPath) df1 <- select(df, "name") - expect_true(columns(df1) == c("name")) - expect_true(count(df1) == 3) + expect_equal(columns(df1), c("name")) + expect_equal(count(df1), 3) df2 <- select(df, df$age) - expect_true(columns(df2) == c("age")) - expect_true(count(df2) == 3) + expect_equal(columns(df2), c("age")) + expect_equal(count(df2), 3) }) test_that("selectExpr() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") - expect_true(names(selected) == "(age * 2)") + expect_equal(names(selected), "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) selected2 <- selectExpr(df, "name as newName", "abs(age) as age") expect_equal(names(selected2), c("newName", "age")) - expect_true(count(selected2) == 3) + expect_equal(count(selected2), 3) }) test_that("column calculation", { df <- jsonFile(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) - expect_true(names(d) == c("age2")) + expect_equal(names(d), c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("read.df() from json file", { df <- read.df(sqlContext, jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) # Check if we can apply a user defined schema schema <- structType(structField("name", type = "string"), structField("age", type = "double")) df1 <- read.df(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df1, "DataFrame")) + expect_is(df1, "DataFrame") expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) # Run the same with loadDF df2 <- loadDF(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) }) @@ -569,8 +569,8 @@ test_that("write.df() as parquet file", { df <- read.df(sqlContext, jsonPath, "json") write.df(df, parquetPath, "parquet", mode="overwrite") df2 <- read.df(sqlContext, parquetPath, "parquet") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("test HiveContext", { @@ -580,17 +580,17 @@ test_that("test HiveContext", { skip("Hive is not build with SparkSQL, skipped") }) df <- createExternalTable(hiveCtx, "json", jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) df2 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") saveAsTable(df, "json", "json", "append", path = jsonPath2) df3 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df3, "DataFrame")) - expect_true(count(df3) == 6) + expect_is(df3, "DataFrame") + expect_equal(count(df3), 6) }) test_that("column operators", { @@ -643,65 +643,65 @@ test_that("string operators", { test_that("group by", { df <- jsonFile(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) df1 <- agg(df, age2 = max(df$age)) - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) expect_equal(columns(df1), c("age2")) gd <- groupBy(df, "name") - expect_true(inherits(gd, "GroupedData")) + expect_is(gd, "GroupedData") df2 <- count(gd) - expect_true(inherits(df2, "DataFrame")) - expect_true(3 == count(df2)) + expect_is(df2, "DataFrame") + expect_equal(3, count(df2)) # Also test group_by, summarize, mean gd1 <- group_by(df, "name") - expect_true(inherits(gd1, "GroupedData")) + expect_is(gd1, "GroupedData") df_summarized <- summarize(gd, mean_age = mean(df$age)) - expect_true(inherits(df_summarized, "DataFrame")) - expect_true(3 == count(df_summarized)) + expect_is(df_summarized, "DataFrame") + expect_equal(3, count(df_summarized)) df3 <- agg(gd, age = "sum") - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) df3 <- agg(gd, age = sum(df$age)) - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) expect_equal(columns(df3), c("name", "age")) df4 <- sum(gd, "age") - expect_true(inherits(df4, "DataFrame")) - expect_true(3 == count(df4)) - expect_true(3 == count(mean(gd, "age"))) - expect_true(3 == count(max(gd, "age"))) + expect_is(df4, "DataFrame") + expect_equal(3, count(df4)) + expect_equal(3, count(mean(gd, "age"))) + expect_equal(3, count(max(gd, "age"))) }) test_that("arrange() and orderBy() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) sorted <- arrange(df, df$age) - expect_true(collect(sorted)[1,2] == "Michael") + expect_equal(collect(sorted)[1,2], "Michael") sorted2 <- arrange(df, "name") - expect_true(collect(sorted2)[2,"age"] == 19) + expect_equal(collect(sorted2)[2,"age"], 19) sorted3 <- orderBy(df, asc(df$age)) expect_true(is.na(first(sorted3)$age)) - expect_true(collect(sorted3)[2, "age"] == 19) + expect_equal(collect(sorted3)[2, "age"], 19) sorted4 <- orderBy(df, desc(df$name)) - expect_true(first(sorted4)$name == "Michael") - expect_true(collect(sorted4)[3,"name"] == "Andy") + expect_equal(first(sorted4)$name, "Michael") + expect_equal(collect(sorted4)[3,"name"], "Andy") }) test_that("filter() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) filtered <- filter(df, "age > 20") - expect_true(count(filtered) == 1) - expect_true(collect(filtered)$name == "Andy") + expect_equal(count(filtered), 1) + expect_equal(collect(filtered)$name, "Andy") filtered2 <- where(df, df$name != "Michael") - expect_true(count(filtered2) == 2) - expect_true(collect(filtered2)$age[2] == 19) + expect_equal(count(filtered2), 2) + expect_equal(collect(filtered2)$age[2], 19) # test suites for %in% filtered3 <- filter(df, "age in (19)") @@ -727,29 +727,29 @@ test_that("join() on a DataFrame", { joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) - expect_true(count(joined) == 12) + expect_equal(count(joined), 12) joined2 <- join(df, df2, df$name == df2$name) expect_equal(names(joined2), c("age", "name", "name", "test")) - expect_true(count(joined2) == 3) + expect_equal(count(joined2), 3) joined3 <- join(df, df2, df$name == df2$name, "right_outer") expect_equal(names(joined3), c("age", "name", "name", "test")) - expect_true(count(joined3) == 4) + expect_equal(count(joined3), 4) expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) joined4 <- select(join(df, df2, df$name == df2$name, "outer"), alias(df$age + 5, "newAge"), df$name, df2$test) expect_equal(names(joined4), c("newAge", "name", "test")) - expect_true(count(joined4) == 4) + expect_equal(count(joined4), 4) expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24) }) test_that("toJSON() returns an RDD of the correct values", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toJSON(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_is(testRDD, "RDD") + expect_equal(SparkR:::getSerializedMode(testRDD), "string") expect_equal(collect(testRDD)[[1]], mockLines[1]) }) @@ -775,50 +775,50 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", { df2 <- read.df(sqlContext, jsonPath2, "json") unioned <- arrange(unionAll(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(unioned) == 6) - expect_true(first(unioned)$name == "Michael") + expect_is(unioned, "DataFrame") + expect_equal(count(unioned), 6) + expect_equal(first(unioned)$name, "Michael") excepted <- arrange(except(df, df2), desc(df$age)) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(excepted) == 2) - expect_true(first(excepted)$name == "Justin") + expect_is(unioned, "DataFrame") + expect_equal(count(excepted), 2) + expect_equal(first(excepted)$name, "Justin") intersected <- arrange(intersect(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(intersected) == 1) - expect_true(first(intersected)$name == "Andy") + expect_is(unioned, "DataFrame") + expect_equal(count(intersected), 1) + expect_equal(first(intersected)$name, "Andy") }) test_that("withColumn() and withColumnRenamed()", { df <- jsonFile(sqlContext, jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- withColumnRenamed(df, "age", "newerAge") - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) test_that("mutate() and rename()", { df <- jsonFile(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- rename(df, newerAge = df$age) - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) test_that("write.df() on DataFrame and works with parquetFile", { df <- jsonFile(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath) - expect_true(inherits(parquetDF, "DataFrame")) + expect_is(parquetDF, "DataFrame") expect_equal(count(df), count(parquetDF)) }) @@ -828,8 +828,8 @@ test_that("parquetFile works with multiple input paths", { parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") write.df(df, parquetPath2, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) - expect_true(inherits(parquetDF, "DataFrame")) - expect_true(count(parquetDF) == count(df)*2) + expect_is(parquetDF, "DataFrame") + expect_equal(count(parquetDF), count(df)*2) }) test_that("describe() on a DataFrame", { @@ -851,58 +851,58 @@ test_that("dropna() on a DataFrame", { expected <- rows[!is.na(rows$name),] actual <- collect(dropna(df, cols = "name")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age),] actual <- collect(dropna(df, cols = "age")) row.names(expected) <- row.names(actual) # identical on two dataframes does not work here. Don't know why. # use identical on all columns as a workaround. - expect_true(identical(expected$age, actual$age)) - expect_true(identical(expected$height, actual$height)) - expect_true(identical(expected$name, actual$name)) + expect_identical(expected$age, actual$age) + expect_identical(expected$height, actual$height) + expect_identical(expected$name, actual$name) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # drop with how expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] actual <- collect(dropna(df, "all")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df, "any")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, "any", cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height),] actual <- collect(dropna(df, "all", cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # drop with threshold expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) + as.integer(!is.na(rows$name)) >= 3,] actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) }) test_that("fillna() on a DataFrame", { @@ -915,22 +915,22 @@ test_that("fillna() on a DataFrame", { expected$age[is.na(expected$age)] <- 50 expected$height[is.na(expected$height)] <- 50.6 actual <- collect(fillna(df, 50.6)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, "unknown")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$age[is.na(expected$age)] <- 50 actual <- collect(fillna(df, 50.6, "age")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, "unknown", c("age", "name"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # fill with named list @@ -939,7 +939,7 @@ test_that("fillna() on a DataFrame", { expected$height[is.na(expected$height)] <- 50.6 expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) }) unlink(parquetPath) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R index c5eb417b40159..c2c724cdc762f 100644 --- a/R/pkg/inst/tests/test_take.R +++ b/R/pkg/inst/tests/test_take.R @@ -59,8 +59,8 @@ test_that("take() gives back the original elements in correct count and order", expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) - expect_true(length(take(strListRDD, 0)) == 0) - expect_true(length(take(strVectorRDD, 0)) == 0) - expect_true(length(take(numListRDD, 0)) == 0) - expect_true(length(take(numVectorRDD, 0)) == 0) + expect_equal(length(take(strListRDD, 0)), 0) + expect_equal(length(take(strVectorRDD, 0)), 0) + expect_equal(length(take(numListRDD, 0)), 0) + expect_equal(length(take(numVectorRDD, 0)), 0) }) diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 092ad9dc10c2e..58318dfef71ab 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -27,9 +27,9 @@ test_that("textFile() on a local file returns an RDD", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(count(rdd) > 0) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName) }) @@ -133,7 +133,7 @@ test_that("textFile() on multiple paths", { writeLines("Spark is awesome.", fileName2) rdd <- textFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1) unlink(fileName2) diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index 15030e6f1d77e..aa0d2a66b9082 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -45,10 +45,10 @@ test_that("serializeToBytes on RDD", { writeLines(mockFile, fileName) text.rdd <- textFile(sc, fileName) - expect_true(getSerializedMode(text.rdd) == "string") + expect_equal(getSerializedMode(text.rdd), "string") ser.rdd <- serializeToBytes(text.rdd) expect_equal(collect(ser.rdd), as.list(mockFile)) - expect_true(getSerializedMode(ser.rdd) == "byte") + expect_equal(getSerializedMode(ser.rdd), "byte") unlink(fileName) }) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index baa4c661cc21e..251a797dc28a2 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -486,11 +486,17 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { // Test for using the util function to change our log levels. test("log4j log level change") { - Utils.setLogLevel(org.apache.log4j.Level.ALL) - assert(log.isInfoEnabled()) - Utils.setLogLevel(org.apache.log4j.Level.ERROR) - assert(!log.isInfoEnabled()) - assert(log.isErrorEnabled()) + val current = org.apache.log4j.Logger.getRootLogger().getLevel() + try { + Utils.setLogLevel(org.apache.log4j.Level.ALL) + assert(log.isInfoEnabled()) + Utils.setLogLevel(org.apache.log4j.Level.ERROR) + assert(!log.isInfoEnabled()) + assert(log.isErrorEnabled()) + } finally { + // Best effort at undoing changes this test made. + Utils.setLogLevel(current) + } } test("deleteRecursively") { diff --git a/dev/run-tests.py b/dev/run-tests.py index 4596e07014733..1f0d218514f92 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -96,8 +96,8 @@ def determine_modules_to_test(changed_modules): ['examples', 'graphx'] >>> x = sorted(x.name for x in determine_modules_to_test([modules.sql])) >>> x # doctest: +NORMALIZE_WHITESPACE - ['examples', 'hive-thriftserver', 'mllib', 'pyspark-core', 'pyspark-ml', \ - 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming', 'sparkr', 'sql'] + ['examples', 'hive-thriftserver', 'mllib', 'pyspark-ml', \ + 'pyspark-mllib', 'pyspark-sql', 'sparkr', 'sql'] """ # If we're going to have to run all of the tests, then we can just short-circuit # and return 'root'. No module depends on root, so if it appears then it will be @@ -293,7 +293,8 @@ def build_spark_sbt(hadoop_version): build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["package", "assembly/assembly", - "streaming-kafka-assembly/assembly"] + "streaming-kafka-assembly/assembly", + "streaming-flume-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ", diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index efe3a897e9c10..993583e2f4119 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -203,7 +203,7 @@ def contains_file(self, filename): streaming_flume = Module( - name="streaming_flume", + name="streaming-flume", dependencies=[streaming], source_file_regexes=[ "external/flume", @@ -214,6 +214,15 @@ def contains_file(self, filename): ) +streaming_flume_assembly = Module( + name="streaming-flume-assembly", + dependencies=[streaming_flume, streaming_flume_sink], + source_file_regexes=[ + "external/flume-assembly", + ] +) + + mllib = Module( name="mllib", dependencies=[streaming, sql], @@ -241,7 +250,7 @@ def contains_file(self, filename): pyspark_core = Module( name="pyspark-core", - dependencies=[mllib, streaming, streaming_kafka], + dependencies=[], source_file_regexes=[ "python/(?!pyspark/(ml|mllib|sql|streaming))" ], @@ -281,7 +290,7 @@ def contains_file(self, filename): pyspark_streaming = Module( name="pyspark-streaming", - dependencies=[pyspark_core, streaming, streaming_kafka], + dependencies=[pyspark_core, streaming, streaming_kafka, streaming_flume_assembly], source_file_regexes=[ "python/pyspark/streaming" ], diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index dfdf6216b270c..eedc23424ad54 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -77,7 +77,7 @@ val ratings = data.map(_.split(',') match { case Array(user, item, rate) => // Build the recommendation model using ALS val rank = 10 -val numIterations = 20 +val numIterations = 10 val model = ALS.train(ratings, rank, numIterations, 0.01) // Evaluate the model on rating data @@ -149,7 +149,7 @@ public class CollaborativeFiltering { // Build the recommendation model using ALS int rank = 10; - int numIterations = 20; + int numIterations = 10; MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); // Evaluate the model on rating data @@ -210,7 +210,7 @@ ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l # Build the recommendation model using Alternating Least Squares rank = 10 -numIterations = 20 +numIterations = 10 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 2a2a7c13186d8..3927d65fbf8fb 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -499,7 +499,7 @@ Note that the Python API does not yet support multiclass classification and mode will in the future. {% highlight python %} -from pyspark.mllib.classification import LogisticRegressionWithLBFGS +from pyspark.mllib.classification import LogisticRegressionWithLBFGS, LogisticRegressionModel from pyspark.mllib.regression import LabeledPoint from numpy import array @@ -518,6 +518,10 @@ model = LogisticRegressionWithLBFGS.train(parsedData) labelsAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) trainErr = labelsAndPreds.filter(lambda (v, p): v != p).count() / float(parsedData.count()) print("Training Error = " + str(trainErr)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LogisticRegressionModel.load(sc, "myModelPath") {% endhighlight %} @@ -668,7 +672,7 @@ values. We compute the mean squared error at the end to evaluate Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD +from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD, LinearRegressionModel from numpy import array # Load and parse the data @@ -686,6 +690,10 @@ model = LinearRegressionWithSGD.train(parsedData) valuesAndPreds = parsedData.map(lambda p: (p.label, model.predict(p.features))) MSE = valuesAndPreds.map(lambda (v, p): (v - p)**2).reduce(lambda x, y: x + y) / valuesAndPreds.count() print("Mean Squared Error = " + str(MSE)) + +# Save and load model +model.save(sc, "myModelPath") +sameModel = LinearRegressionModel.load(sc, "myModelPath") {% endhighlight %} diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index bf6d124fd5d8d..e73bd30f3a90a 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -119,7 +119,7 @@ used for evaluation and prediction. Note that the Python API does not yet support model save/load but will in the future. {% highlight python %} -from pyspark.mllib.classification import NaiveBayes +from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel from pyspark.mllib.linalg import Vectors from pyspark.mllib.regression import LabeledPoint @@ -140,6 +140,10 @@ model = NaiveBayes.train(training, 1.0) # Make prediction and test accuracy. predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label)) accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count() + +# Save and load model +model.save(sc, "myModelPath") +sameModel = NaiveBayesModel.load(sc, "myModelPath") {% endhighlight %} diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index 8d6e74370918f..de0461010daec 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -58,6 +58,15 @@ configuring Flume agents. See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java). +
+ from pyspark.streaming.flume import FlumeUtils + + flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) + + By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils) + and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/flume_wordcount.py). +
Note that the hostname should be the same as the one used by the resource manager in the @@ -135,6 +144,15 @@ configuring Flume agents. JavaReceiverInputDStreamflumeStream = FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]); +
+ from pyspark.streaming.flume import FlumeUtils + + addresses = [([sink machine hostname 1], [sink port 1]), ([sink machine hostname 2], [sink port 2])] + flumeStream = FlumeUtils.createPollingStream(streamingContext, addresses) + + By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils). +
See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index b784d59666fec..e72d5580dae55 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -683,7 +683,7 @@ for Java, and [StreamingContext](api/python/pyspark.streaming.html#pyspark.strea {:.no_toc} Python API As of Spark {{site.SPARK_VERSION_SHORT}}, -out of these sources, *only* Kafka is available in the Python API. We will add more advanced sources in the Python API in future. +out of these sources, *only* Kafka and Flume are available in the Python API. We will add more advanced sources in the Python API in future. This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts diff --git a/examples/src/main/python/streaming/flume_wordcount.py b/examples/src/main/python/streaming/flume_wordcount.py new file mode 100644 index 0000000000000..091b64d8c4af4 --- /dev/null +++ b/examples/src/main/python/streaming/flume_wordcount.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. + Usage: flume_wordcount.py + + To run this on your local machine, you need to setup Flume first, see + https://flume.apache.org/documentation.html + + and then run the example + `$ bin/spark-submit --jars external/flume-assembly/target/scala-*/\ + spark-streaming-flume-assembly-*.jar examples/src/main/python/streaming/flume_wordcount.py \ + localhost 12345 +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext +from pyspark.streaming.flume import FlumeUtils + +if __name__ == "__main__": + if len(sys.argv) != 3: + print("Usage: flume_wordcount.py ", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonStreamingFlumeWordCount") + ssc = StreamingContext(sc, 1) + + hostname, port = sys.argv[1:] + kvs = FlumeUtils.createStream(ssc, hostname, int(port)) + lines = kvs.map(lambda x: x[1]) + counts = lines.flatMap(lambda line: line.split(" ")) \ + .map(lambda word: (word, 1)) \ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml new file mode 100644 index 0000000000000..8565cd83edfa2 --- /dev/null +++ b/external/flume-assembly/pom.xml @@ -0,0 +1,135 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.5.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-streaming-flume-assembly_2.10 + jar + Spark Project External Flume Assembly + http://spark.apache.org/ + + + streaming-flume-assembly + + + + + org.apache.spark + spark-streaming-flume_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + provided + + + org.apache.avro + avro + ${avro.version} + + + org.apache.avro + avro-ipc + ${avro.version} + + + io.netty + netty + + + org.mortbay.jetty + jetty + + + org.mortbay.jetty + jetty-util + + + org.mortbay.jetty + servlet-api + + + org.apache.velocity + velocity + + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-streaming-flume-assembly-${project.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + reference.conf + + + log4j.properties + + + + + + + + + + + + diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala new file mode 100644 index 0000000000000..9d9c3b189415f --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.net.{InetSocketAddress, ServerSocket} +import java.nio.ByteBuffer +import java.util.{List => JList} + +import scala.collection.JavaConversions._ + +import com.google.common.base.Charsets.UTF_8 +import org.apache.avro.ipc.NettyTransceiver +import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.commons.lang3.RandomUtils +import org.apache.flume.source.avro +import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent} +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.socket.SocketChannel +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.jboss.netty.handler.codec.compression.{ZlibDecoder, ZlibEncoder} + +import org.apache.spark.util.Utils +import org.apache.spark.SparkConf + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class FlumeTestUtils { + + private var transceiver: NettyTransceiver = null + + private val testPort: Int = findFreePort() + + def getTestPort(): Int = testPort + + /** Find a free port */ + private def findFreePort(): Int = { + val candidatePort = RandomUtils.nextInt(1024, 65536) + Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + }, new SparkConf())._2 + } + + /** Send data to the flume receiver */ + def writeInput(input: JList[String], enableCompression: Boolean): Unit = { + val testAddress = new InetSocketAddress("localhost", testPort) + + val inputEvents = input.map { item => + val event = new AvroFlumeEvent + event.setBody(ByteBuffer.wrap(item.getBytes(UTF_8))) + event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) + event + } + + // if last attempted transceiver had succeeded, close it + close() + + // Create transceiver + transceiver = { + if (enableCompression) { + new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) + } else { + new NettyTransceiver(testAddress) + } + } + + // Create Avro client with the transceiver + val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) + if (client == null) { + throw new AssertionError("Cannot create client") + } + + // Send data + val status = client.appendBatch(inputEvents.toList) + if (status != avro.Status.OK) { + throw new AssertionError("Sent events unsuccessfully") + } + } + + def close(): Unit = { + if (transceiver != null) { + transceiver.close() + transceiver = null + } + } + + /** Class to create socket channel with compression */ + private class CompressionChannelFactory(compressionLevel: Int) + extends NioClientSocketChannelFactory { + + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { + val encoder = new ZlibEncoder(compressionLevel) + pipeline.addFirst("deflater", encoder) + pipeline.addFirst("inflater", new ZlibDecoder()) + super.newChannel(pipeline) + } + } + +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 44dec45c227ca..095bfb0c73a9a 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -18,10 +18,16 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress +import java.io.{DataOutputStream, ByteArrayOutputStream} +import java.util.{List => JList, Map => JMap} +import scala.collection.JavaConversions._ + +import org.apache.spark.api.java.function.PairFunction +import org.apache.spark.api.python.PythonRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream @@ -236,3 +242,71 @@ object FlumeUtils { createPollingStream(jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) } } + +/** + * This is a helper class that wraps the methods in FlumeUtils into more Python-friendly class and + * function so that it can be easily instantiated and called from Python's FlumeUtils. + */ +private class FlumeUtilsPythonHelper { + + def createStream( + jssc: JavaStreamingContext, + hostname: String, + port: Int, + storageLevel: StorageLevel, + enableDecompression: Boolean + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + val dstream = FlumeUtils.createStream(jssc, hostname, port, storageLevel, enableDecompression) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + + def createPollingStream( + jssc: JavaStreamingContext, + hosts: JList[String], + ports: JList[Int], + storageLevel: StorageLevel, + maxBatchSize: Int, + parallelism: Int + ): JavaPairDStream[Array[Byte], Array[Byte]] = { + assert(hosts.length == ports.length) + val addresses = hosts.zip(ports).map { + case (host, port) => new InetSocketAddress(host, port) + } + val dstream = FlumeUtils.createPollingStream( + jssc.ssc, addresses, storageLevel, maxBatchSize, parallelism) + FlumeUtilsPythonHelper.toByteArrayPairDStream(dstream) + } + +} + +private object FlumeUtilsPythonHelper { + + private def stringMapToByteArray(map: JMap[CharSequence, CharSequence]): Array[Byte] = { + val byteStream = new ByteArrayOutputStream() + val output = new DataOutputStream(byteStream) + try { + output.writeInt(map.size) + map.foreach { kv => + PythonRDD.writeUTF(kv._1.toString, output) + PythonRDD.writeUTF(kv._2.toString, output) + } + byteStream.toByteArray + } + finally { + output.close() + } + } + + private def toByteArrayPairDStream(dstream: JavaReceiverInputDStream[SparkFlumeEvent]): + JavaPairDStream[Array[Byte], Array[Byte]] = { + dstream.mapToPair(new PairFunction[SparkFlumeEvent, Array[Byte], Array[Byte]] { + override def call(sparkEvent: SparkFlumeEvent): (Array[Byte], Array[Byte]) = { + val event = sparkEvent.event + val byteBuffer = event.getBody + val body = new Array[Byte](byteBuffer.remaining()) + byteBuffer.get(body) + (stringMapToByteArray(event.getHeaders), body) + } + }) + } +} diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala new file mode 100644 index 0000000000000..91d63d49dbec3 --- /dev/null +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.flume + +import java.util.concurrent._ +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConversions._ +import scala.collection.mutable.ArrayBuffer + +import com.google.common.base.Charsets.UTF_8 +import org.apache.flume.event.EventBuilder +import org.apache.flume.Context +import org.apache.flume.channel.MemoryChannel +import org.apache.flume.conf.Configurables + +import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink} + +/** + * Share codes for Scala and Python unit tests + */ +private[flume] class PollingFlumeTestUtils { + + private val batchCount = 5 + val eventsPerBatch = 100 + private val totalEventsPerChannel = batchCount * eventsPerBatch + private val channelCapacity = 5000 + + def getTotalEvents: Int = totalEventsPerChannel * channels.size + + private val channels = new ArrayBuffer[MemoryChannel] + private val sinks = new ArrayBuffer[SparkSink] + + /** + * Start a sink and return the port of this sink + */ + def startSingleSink(): Int = { + channels.clear() + sinks.clear() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + channels += (channel) + sinks += sink + + sink.getPort() + } + + /** + * Start 2 sinks and return the ports + */ + def startMultipleSinks(): JList[Int] = { + channels.clear() + sinks.clear() + + // Start the channel and sink. + val context = new Context() + context.put("capacity", channelCapacity.toString) + context.put("transactionCapacity", "1000") + context.put("keep-alive", "0") + val channel = new MemoryChannel() + Configurables.configure(channel, context) + + val channel2 = new MemoryChannel() + Configurables.configure(channel2, context) + + val sink = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink, context) + sink.setChannel(channel) + sink.start() + + val sink2 = new SparkSink() + context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") + context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) + Configurables.configure(sink2, context) + sink2.setChannel(channel2) + sink2.start() + + sinks += sink + sinks += sink2 + channels += channel + channels += channel2 + + sinks.map(_.getPort()) + } + + /** + * Send data and wait until all data has been received + */ + def sendDatAndEnsureAllDataHasBeenReceived(): Unit = { + val executor = Executors.newCachedThreadPool() + val executorCompletion = new ExecutorCompletionService[Void](executor) + + val latch = new CountDownLatch(batchCount * channels.size) + sinks.foreach(_.countdownWhenBatchReceived(latch)) + + channels.foreach(channel => { + executorCompletion.submit(new TxnSubmitter(channel)) + }) + + for (i <- 0 until channels.size) { + executorCompletion.take() + } + + latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. + } + + /** + * A Python-friendly method to assert the output + */ + def assertOutput( + outputHeaders: JList[JMap[String, String]], outputBodies: JList[String]): Unit = { + require(outputHeaders.size == outputBodies.size) + val eventSize = outputHeaders.size + if (eventSize != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"Expected ${totalEventsPerChannel * channels.size} events, but was $eventSize") + } + var counter = 0 + for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { + val eventBodyToVerify = s"${channels(k).getName}-$i" + val eventHeaderToVerify: JMap[String, String] = Map[String, String](s"test-$i" -> "header") + var found = false + var j = 0 + while (j < eventSize && !found) { + if (eventBodyToVerify == outputBodies.get(j) && + eventHeaderToVerify == outputHeaders.get(j)) { + found = true + counter += 1 + } + j += 1 + } + } + if (counter != totalEventsPerChannel * channels.size) { + throw new AssertionError( + s"111 Expected ${totalEventsPerChannel * channels.size} events, but was $counter") + } + } + + def assertChannelsAreEmpty(): Unit = { + channels.foreach(assertChannelIsEmpty) + } + + private def assertChannelIsEmpty(channel: MemoryChannel): Unit = { + val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") + queueRemaining.setAccessible(true) + val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") + if (m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] != 5000) { + throw new AssertionError(s"Channel ${channel.getName} is not empty") + } + } + + def close(): Unit = { + sinks.foreach(_.stop()) + sinks.clear() + channels.foreach(_.stop()) + channels.clear() + } + + private class TxnSubmitter(channel: MemoryChannel) extends Callable[Void] { + override def call(): Void = { + var t = 0 + for (i <- 0 until batchCount) { + val tx = channel.getTransaction + tx.begin() + for (j <- 0 until eventsPerBatch) { + channel.put(EventBuilder.withBody(s"${channel.getName}-$t".getBytes(UTF_8), + Map[String, String](s"test-$t" -> "header"))) + t += 1 + } + tx.commit() + tx.close() + Thread.sleep(500) // Allow some time for the events to reach + } + null + } + } + +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index d772b9ca9b570..d5f9a0aa38f9f 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -18,47 +18,33 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress -import java.util.concurrent._ import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} import scala.concurrent.duration._ import scala.language.postfixOps -import org.apache.flume.Context -import org.apache.flume.channel.MemoryChannel -import org.apache.flume.conf.Configurables -import org.apache.flume.event.EventBuilder -import org.scalatest.concurrent.Eventually._ - +import com.google.common.base.Charsets.UTF_8 import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} -import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.{ManualClock, Utils} class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { - val batchCount = 5 - val eventsPerBatch = 100 - val totalEventsPerChannel = batchCount * eventsPerBatch - val channelCapacity = 5000 val maxAttempts = 5 val batchDuration = Seconds(1) val conf = new SparkConf() .setMaster("local[2]") .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - def beforeFunction() { - logInfo("Using manual clock") - conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock") - } - - before(beforeFunction()) + val utils = new PollingFlumeTestUtils test("flume polling test") { testMultipleTimes(testFlumePolling) @@ -89,146 +75,55 @@ class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Log } private def testFlumePolling(): Unit = { - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - writeAndVerify(Seq(sink), Seq(channel)) - assertChannelIsEmpty(channel) - sink.stop() - channel.stop() + try { + val port = utils.startSingleSink() + + writeAndVerify(Seq(port)) + utils.assertChannelsAreEmpty() + } finally { + utils.close() + } } private def testFlumePollingMultipleHost(): Unit = { - // Start the channel and sink. - val context = new Context() - context.put("capacity", channelCapacity.toString) - context.put("transactionCapacity", "1000") - context.put("keep-alive", "0") - val channel = new MemoryChannel() - Configurables.configure(channel, context) - - val channel2 = new MemoryChannel() - Configurables.configure(channel2, context) - - val sink = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink, context) - sink.setChannel(channel) - sink.start() - - val sink2 = new SparkSink() - context.put(SparkSinkConfig.CONF_HOSTNAME, "localhost") - context.put(SparkSinkConfig.CONF_PORT, String.valueOf(0)) - Configurables.configure(sink2, context) - sink2.setChannel(channel2) - sink2.start() try { - writeAndVerify(Seq(sink, sink2), Seq(channel, channel2)) - assertChannelIsEmpty(channel) - assertChannelIsEmpty(channel2) + val ports = utils.startMultipleSinks() + writeAndVerify(ports) + utils.assertChannelsAreEmpty() } finally { - sink.stop() - sink2.stop() - channel.stop() - channel2.stop() + utils.close() } } - def writeAndVerify(sinks: Seq[SparkSink], channels: Seq[MemoryChannel]) { + def writeAndVerify(sinkPorts: Seq[Int]): Unit = { // Set up the streaming context and input streams val ssc = new StreamingContext(conf, batchDuration) - val addresses = sinks.map(sink => new InetSocketAddress("localhost", sink.getPort())) + val addresses = sinkPorts.map(port => new InetSocketAddress("localhost", port)) val flumeStream: ReceiverInputDStream[SparkFlumeEvent] = FlumeUtils.createPollingStream(ssc, addresses, StorageLevel.MEMORY_AND_DISK, - eventsPerBatch, 5) + utils.eventsPerBatch, 5) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) outputStream.register() ssc.start() - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val executor = Executors.newCachedThreadPool() - val executorCompletion = new ExecutorCompletionService[Void](executor) - - val latch = new CountDownLatch(batchCount * channels.size) - sinks.foreach(_.countdownWhenBatchReceived(latch)) - - channels.foreach(channel => { - executorCompletion.submit(new TxnSubmitter(channel, clock)) - }) - - for (i <- 0 until channels.size) { - executorCompletion.take() - } - - latch.await(15, TimeUnit.SECONDS) // Ensure all data has been received. - clock.advance(batchDuration.milliseconds) - - // The eventually is required to ensure that all data in the batch has been processed. - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val flattenedBuffer = outputBuffer.flatten - assert(flattenedBuffer.size === totalEventsPerChannel * channels.size) - var counter = 0 - for (k <- 0 until channels.size; i <- 0 until totalEventsPerChannel) { - val eventToVerify = EventBuilder.withBody((channels(k).getName + " - " + - String.valueOf(i)).getBytes("utf-8"), - Map[String, String]("test-" + i.toString -> "header")) - var found = false - var j = 0 - while (j < flattenedBuffer.size && !found) { - val strToCompare = new String(flattenedBuffer(j).event.getBody.array(), "utf-8") - if (new String(eventToVerify.getBody, "utf-8") == strToCompare && - eventToVerify.getHeaders.get("test-" + i.toString) - .equals(flattenedBuffer(j).event.getHeaders.get("test-" + i.toString))) { - found = true - counter += 1 - } - j += 1 - } - } - assert(counter === totalEventsPerChannel * channels.size) - } - ssc.stop() - } - - def assertChannelIsEmpty(channel: MemoryChannel): Unit = { - val queueRemaining = channel.getClass.getDeclaredField("queueRemaining") - queueRemaining.setAccessible(true) - val m = queueRemaining.get(channel).getClass.getDeclaredMethod("availablePermits") - assert(m.invoke(queueRemaining.get(channel)).asInstanceOf[Int] === 5000) - } - - private class TxnSubmitter(channel: MemoryChannel, clock: ManualClock) extends Callable[Void] { - override def call(): Void = { - var t = 0 - for (i <- 0 until batchCount) { - val tx = channel.getTransaction - tx.begin() - for (j <- 0 until eventsPerBatch) { - channel.put(EventBuilder.withBody((channel.getName + " - " + String.valueOf(t)).getBytes( - "utf-8"), - Map[String, String]("test-" + t.toString -> "header"))) - t += 1 - } - tx.commit() - tx.close() - Thread.sleep(500) // Allow some time for the events to reach + try { + utils.sendDatAndEnsureAllDataHasBeenReceived() + val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] + clock.advance(batchDuration.milliseconds) + + // The eventually is required to ensure that all data in the batch has been processed. + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val flattenOutputBuffer = outputBuffer.flatten + val headers = flattenOutputBuffer.map(_.event.getHeaders.map { + case kv => (kv._1.toString, kv._2.toString) + }).map(mapAsJavaMap) + val bodies = flattenOutputBuffer.map(e => new String(e.event.getBody.array(), UTF_8)) + utils.assertOutput(headers, bodies) } - null + } finally { + ssc.stop() } } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index c926359987d89..5bc4cdf65306c 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,20 +17,12 @@ package org.apache.spark.streaming.flume -import java.net.{InetSocketAddress, ServerSocket} -import java.nio.ByteBuffer - import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.base.Charsets -import org.apache.avro.ipc.NettyTransceiver -import org.apache.avro.ipc.specific.SpecificRequestor -import org.apache.commons.lang3.RandomUtils -import org.apache.flume.source.avro -import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory @@ -41,22 +33,10 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} -import org.apache.spark.util.Utils class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") - var ssc: StreamingContext = null - var transceiver: NettyTransceiver = null - - after { - if (ssc != null) { - ssc.stop() - } - if (transceiver != null) { - transceiver.close() - } - } test("flume input stream") { testFlumeStream(testCompression = false) @@ -69,19 +49,29 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w /** Run test on flume stream */ private def testFlumeStream(testCompression: Boolean): Unit = { val input = (1 to 100).map { _.toString } - val testPort = findFreePort() - val outputBuffer = startContext(testPort, testCompression) - writeAndVerify(input, testPort, outputBuffer, testCompression) - } + val utils = new FlumeTestUtils + try { + val outputBuffer = startContext(utils.getTestPort(), testCompression) - /** Find a free port */ - private def findFreePort(): Int = { - val candidatePort = RandomUtils.nextInt(1024, 65536) - Utils.startServiceOnPort(candidatePort, (trialPort: Int) => { - val socket = new ServerSocket(trialPort) - socket.close() - (null, trialPort) - }, conf)._2 + eventually(timeout(10 seconds), interval(100 milliseconds)) { + utils.writeInput(input, testCompression) + } + + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val outputEvents = outputBuffer.flatten.map { _.event } + outputEvents.foreach { + event => + event.getHeaders.get("test") should be("header") + } + val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) + output should be (input) + } + } finally { + if (ssc != null) { + ssc.stop() + } + utils.close() + } } /** Setup and start the streaming context */ @@ -98,58 +88,6 @@ class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers w outputBuffer } - /** Send data to the flume receiver and verify whether the data was received */ - private def writeAndVerify( - input: Seq[String], - testPort: Int, - outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]], - enableCompression: Boolean - ) { - val testAddress = new InetSocketAddress("localhost", testPort) - - val inputEvents = input.map { item => - val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(item.getBytes(Charsets.UTF_8))) - event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) - event - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - // if last attempted transceiver had succeeded, close it - if (transceiver != null) { - transceiver.close() - transceiver = null - } - - // Create transceiver - transceiver = { - if (enableCompression) { - new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) - } else { - new NettyTransceiver(testAddress) - } - } - - // Create Avro client with the transceiver - val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) - client should not be null - - // Send data - val status = client.appendBatch(inputEvents.toList) - status should be (avro.Status.OK) - } - - eventually(timeout(10 seconds), interval(100 milliseconds)) { - val outputEvents = outputBuffer.flatten.map { _.event } - outputEvents.foreach { - event => - event.getHeaders.get("test") should be("header") - } - val output = outputEvents.map(event => new String(event.getBody.array(), Charsets.UTF_8)) - output should be (input) - } - } - /** Class to create socket channel with compression */ private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index f138251748c9e..3636a9037d43f 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -39,6 +39,13 @@ spark-core_${scala.binary.version} ${project.version} + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-streaming_${scala.binary.version} @@ -49,6 +56,7 @@ spark-streaming_${scala.binary.version} ${project.version} test-jar + test junit diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index a66a404d5c846..458fab48fef5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -75,6 +75,15 @@ private[python] class PythonMLLibAPI extends Serializable { minPartitions: Int): JavaRDD[LabeledPoint] = MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions) + /** + * Loads and serializes vectors saved with `RDD#saveAsTextFile`. + * @param jsc Java SparkContext + * @param path file or directory path in any Hadoop-supported file system URI + * @return serialized vectors in a RDD + */ + def loadVectors(jsc: JavaSparkContext, path: String): RDD[Vector] = + MLUtils.loadVectors(jsc.sc, path) + private def trainRegressionModel( learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], diff --git a/pom.xml b/pom.xml index 94dd512cfb618..211da9ee74a3f 100644 --- a/pom.xml +++ b/pom.xml @@ -102,6 +102,7 @@ external/twitter external/flume external/flume-sink + external/flume-assembly external/mqtt external/zeromq examples diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f5f1c9a1a247a..5f389bcc9ceeb 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -45,8 +45,8 @@ object BuildCommons { sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly) = - Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly") + val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly) = + Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly") .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") @@ -161,7 +161,7 @@ object SparkBuild extends PomBuild { // Note ordering of these settings matter. /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) - .foreach(enable(sharedSettings ++ ExludedDependencies.settings ++ Revolver.settings)) + .foreach(enable(sharedSettings ++ ExcludedDependencies.settings ++ Revolver.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -246,7 +246,7 @@ object Flume { This excludes library dependencies in sbt, which are specified in maven but are not needed by sbt build. */ -object ExludedDependencies { +object ExcludedDependencies { lazy val settings = Seq( libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") } ) @@ -347,7 +347,7 @@ object Assembly { .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String]) }, jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) => - if (mName.contains("streaming-kafka-assembly")) { + if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-assembly")) { // This must match the same name used in maven (see external/kafka-assembly/pom.xml) s"${mName}-${v}.jar" } else { diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index e3c8a24c4a751..a3eab635282f6 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -288,16 +288,12 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model = PowerIterationClustering.train(rdd, 2, 100) >>> model.k 2 - >>> sorted(model.assignments().collect()) - [Assignment(id=0, cluster=1), Assignment(id=1, cluster=0), ... >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = PowerIterationClusteringModel.load(sc, path) >>> sameModel.k 2 - >>> sorted(sameModel.assignments().collect()) - [Assignment(id=0, cluster=1), Assignment(id=1, cluster=0), ... >>> from shutil import rmtree >>> try: ... rmtree(path) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f0091d6faccce..49ce125de7e78 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -54,6 +54,7 @@ from pyspark.mllib.feature import IDF from pyspark.mllib.feature import StandardScaler, ElementwiseProduct from pyspark.mllib.util import LinearDataGenerator +from pyspark.mllib.util import MLUtils from pyspark.serializers import PickleSerializer from pyspark.streaming import StreamingContext from pyspark.sql import SQLContext @@ -1290,6 +1291,48 @@ def func(rdd): self.assertTrue(mean_absolute_errors[1] - mean_absolute_errors[-1] > 2) +class MLUtilsTests(MLlibTestCase): + def test_append_bias(self): + data = [2.0, 2.0, 2.0] + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_vector(self): + data = Vectors.dense([2.0, 2.0, 2.0]) + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_sp_vector(self): + data = Vectors.sparse(3, {0: 2.0, 2: 2.0}) + expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0}) + # Returned value must be SparseVector + ret = MLUtils.appendBias(data) + self.assertEqual(ret, expected) + self.assertEqual(type(ret), SparseVector) + + def test_load_vectors(self): + import shutil + data = [ + [1.0, 2.0, 3.0], + [1.0, 2.0, 3.0] + ] + temp_dir = tempfile.mkdtemp() + load_vectors_path = os.path.join(temp_dir, "test_load_vectors") + try: + self.sc.parallelize(data).saveAsTextFile(load_vectors_path) + ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path) + ret = ret_rdd.collect() + self.assertEqual(len(ret), 2) + self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) + self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) + except: + self.fail() + finally: + shutil.rmtree(load_vectors_path) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 348238319e407..875d3b2d642c6 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -169,6 +169,28 @@ def loadLabeledPoints(sc, path, minPartitions=None): minPartitions = minPartitions or min(sc.defaultParallelism, 2) return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) + @staticmethod + def appendBias(data): + """ + Returns a new vector with `1.0` (bias) appended to + the end of the input vector. + """ + vec = _convert_to_vector(data) + if isinstance(vec, SparseVector): + newIndices = np.append(vec.indices, len(vec)) + newValues = np.append(vec.values, 1.0) + return SparseVector(len(vec) + 1, newIndices, newValues) + else: + return _convert_to_vector(np.append(vec.toArray(), 1.0)) + + @staticmethod + def loadVectors(sc, path): + """ + Loads vectors saved using `RDD[Vector].saveAsTextFile` + with the default number of partitions. + """ + return callMLlibFunc("loadVectors", sc, path) + class Saveable(object): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 4b9efa0a210fb..1e9c657cf81b3 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -484,13 +484,12 @@ def dtypes(self): return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields] @property - @ignore_unicode_prefix @since(1.3) def columns(self): """Returns all column names as a list. >>> df.columns - [u'age', u'name'] + ['age', 'name'] """ return [f.name for f in self.schema.fields] @@ -803,11 +802,11 @@ def groupBy(self, *cols): Each element should be a column name (string) or an expression (:class:`Column`). >>> df.groupBy().avg().collect() - [Row(AVG(age)=3.5)] + [Row(avg(age)=3.5)] >>> df.groupBy('name').agg({'age': 'mean'}).collect() - [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] + [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> df.groupBy(df.name).avg().collect() - [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)] + [Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)] >>> df.groupBy(['name', df.age]).count().collect() [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)] """ @@ -865,10 +864,10 @@ def agg(self, *exprs): (shorthand for ``df.groupBy.agg()``). >>> df.agg({"age": "max"}).collect() - [Row(MAX(age)=5)] + [Row(max(age)=5)] >>> from pyspark.sql import functions as F >>> df.agg(F.min(df.age)).collect() - [Row(MIN(age)=2)] + [Row(min(age)=2)] """ return self.groupBy().agg(*exprs) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f4c5ec5d6d131..1fa24095dbaaf 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -266,7 +266,7 @@ def coalesce(*cols): >>> cDf.select(coalesce(cDf["a"], cDf["b"])).show() +-------------+ - |Coalesce(a,b)| + |coalesce(a,b)| +-------------+ | null| | 1| @@ -275,7 +275,7 @@ def coalesce(*cols): >>> cDf.select('*', coalesce(cDf["a"], lit(0.0))).show() +----+----+---------------+ - | a| b|Coalesce(a,0.0)| + | a| b|coalesce(a,0.0)| +----+----+---------------+ |null|null| 0.0| | 1|null| 1.0| diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index 5a37a673ee80c..04594d5a836ce 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -75,11 +75,11 @@ def agg(self, *exprs): >>> gdf = df.groupBy(df.name) >>> gdf.agg({"*": "count"}).collect() - [Row(name=u'Alice', COUNT(1)=1), Row(name=u'Bob', COUNT(1)=1)] + [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)] >>> from pyspark.sql import functions as F >>> gdf.agg(F.min(df.age)).collect() - [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)] + [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)] """ assert exprs, "exprs should not be empty" if len(exprs) == 1 and isinstance(exprs[0], dict): @@ -110,9 +110,9 @@ def mean(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().mean('age').collect() - [Row(AVG(age)=3.5)] + [Row(avg(age)=3.5)] >>> df3.groupBy().mean('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] + [Row(avg(age)=3.5, avg(height)=82.5)] """ @df_varargs_api @@ -125,9 +125,9 @@ def avg(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().avg('age').collect() - [Row(AVG(age)=3.5)] + [Row(avg(age)=3.5)] >>> df3.groupBy().avg('age', 'height').collect() - [Row(AVG(age)=3.5, AVG(height)=82.5)] + [Row(avg(age)=3.5, avg(height)=82.5)] """ @df_varargs_api @@ -136,9 +136,9 @@ def max(self, *cols): """Computes the max value for each numeric columns for each group. >>> df.groupBy().max('age').collect() - [Row(MAX(age)=5)] + [Row(max(age)=5)] >>> df3.groupBy().max('age', 'height').collect() - [Row(MAX(age)=5, MAX(height)=85)] + [Row(max(age)=5, max(height)=85)] """ @df_varargs_api @@ -149,9 +149,9 @@ def min(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().min('age').collect() - [Row(MIN(age)=2)] + [Row(min(age)=2)] >>> df3.groupBy().min('age', 'height').collect() - [Row(MIN(age)=2, MIN(height)=80)] + [Row(min(age)=2, min(height)=80)] """ @df_varargs_api @@ -162,9 +162,9 @@ def sum(self, *cols): :param cols: list of column names (string). Non-numeric columns are ignored. >>> df.groupBy().sum('age').collect() - [Row(SUM(age)=7)] + [Row(sum(age)=7)] >>> df3.groupBy().sum('age', 'height').collect() - [Row(SUM(age)=7, SUM(height)=165)] + [Row(sum(age)=7, sum(height)=165)] """ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5af2ce09bc122..333378c7f1854 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1,3 +1,4 @@ +# -*- encoding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -628,6 +629,14 @@ def test_access_column(self): self.assertRaises(IndexError, lambda: df["bad_key"]) self.assertRaises(TypeError, lambda: df[{}]) + def test_column_name_with_non_ascii(self): + df = self.sqlCtx.createDataFrame([(1,)], ["数量"]) + self.assertEqual(StructType([StructField("数量", LongType(), True)]), df.schema) + self.assertEqual("DataFrame[数量: bigint]", str(df)) + self.assertEqual([("数量", 'bigint')], df.dtypes) + self.assertEqual(1, df.select("数量").first()[0]) + self.assertEqual(1, df.select(df["数量"]).first()[0]) + def test_access_nested_types(self): df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF() self.assertEqual(1, df.select(df.l[0]).first()[0]) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ae9344e6106a4..160df40d65cc1 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -324,6 +324,8 @@ def __init__(self, name, dataType, nullable=True, metadata=None): False """ assert isinstance(dataType, DataType), "dataType should be DataType" + if not isinstance(name, str): + name = name.encode('utf-8') self.name = name self.dataType = dataType self.nullable = nullable diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 8096802e7302f..cc5b2c088b7cc 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -29,9 +29,9 @@ def deco(*a, **kw): try: return f(*a, **kw) except py4j.protocol.Py4JJavaError as e: - cls, msg = e.java_exception.toString().split(': ', 1) - if cls == 'org.apache.spark.sql.AnalysisException': - raise AnalysisException(msg) + s = e.java_exception.toString() + if s.startswith('org.apache.spark.sql.AnalysisException: '): + raise AnalysisException(s.split(': ', 1)[1]) raise return deco diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py new file mode 100644 index 0000000000000..cbb573f226bbe --- /dev/null +++ b/python/pyspark/streaming/flume.py @@ -0,0 +1,147 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +if sys.version >= "3": + from io import BytesIO +else: + from StringIO import StringIO +from py4j.java_gateway import Py4JJavaError + +from pyspark.storagelevel import StorageLevel +from pyspark.serializers import PairDeserializer, NoOpSerializer, UTF8Deserializer, read_int +from pyspark.streaming import DStream + +__all__ = ['FlumeUtils', 'utf8_decoder'] + + +def utf8_decoder(s): + """ Decode the unicode as UTF-8 """ + return s and s.decode('utf-8') + + +class FlumeUtils(object): + + @staticmethod + def createStream(ssc, hostname, port, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + enableDecompression=False, + bodyDecoder=utf8_decoder): + """ + Create an input stream that pulls events from Flume. + + :param ssc: StreamingContext object + :param hostname: Hostname of the slave machine to which the flume data will be sent + :param port: Port of the slave machine to which the flume data will be sent + :param storageLevel: Storage level to use for storing the received objects + :param enableDecompression: Should netty server decompress input stream + :param bodyDecoder: A function used to decode body (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\ + .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + FlumeUtils._printErrorMsg(ssc.sparkContext) + raise e + + return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) + + @staticmethod + def createPollingStream(ssc, addresses, + storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + maxBatchSize=1000, + parallelism=5, + bodyDecoder=utf8_decoder): + """ + Creates an input stream that is to be used with the Spark Sink deployed on a Flume agent. + This stream will poll the sink for data and will pull events as they are available. + + :param ssc: StreamingContext object + :param addresses: List of (host, port)s on which the Spark Sink is running. + :param storageLevel: Storage level to use for storing the received objects + :param maxBatchSize: The maximum number of events to be pulled from the Spark sink + in a single RPC call + :param parallelism: Number of concurrent requests this stream should send to the sink. + Note that having a higher number of requests concurrently being pulled + will result in this stream using more threads + :param bodyDecoder: A function used to decode body (default is utf8_decoder) + :return: A DStream object + """ + jlevel = ssc._sc._getJavaStorageLevel(storageLevel) + hosts = [] + ports = [] + for (host, port) in addresses: + hosts.append(host) + ports.append(port) + + try: + helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.FlumeUtilsPythonHelper") + helper = helperClass.newInstance() + jstream = helper.createPollingStream( + ssc._jssc, hosts, ports, jlevel, maxBatchSize, parallelism) + except Py4JJavaError as e: + if 'ClassNotFoundException' in str(e.java_exception): + FlumeUtils._printErrorMsg(ssc.sparkContext) + raise e + + return FlumeUtils._toPythonDStream(ssc, jstream, bodyDecoder) + + @staticmethod + def _toPythonDStream(ssc, jstream, bodyDecoder): + ser = PairDeserializer(NoOpSerializer(), NoOpSerializer()) + stream = DStream(jstream, ssc, ser) + + def func(event): + headersBytes = BytesIO(event[0]) if sys.version >= "3" else StringIO(event[0]) + headers = {} + strSer = UTF8Deserializer() + for i in range(0, read_int(headersBytes)): + key = strSer.loads(headersBytes) + value = strSer.loads(headersBytes) + headers[key] = value + body = bodyDecoder(event[1]) + return (headers, body) + return stream.map(func) + + @staticmethod + def _printErrorMsg(sc): + print(""" +________________________________________________________________________________________________ + + Spark Streaming's Flume libraries not found in class path. Try one of the following. + + 1. Include the Flume library and its dependencies with in the + spark-submit command as + + $ bin/spark-submit --packages org.apache.spark:spark-streaming-flume:%s ... + + 2. Download the JAR of the artifact from Maven Central http://search.maven.org/, + Group Id = org.apache.spark, Artifact Id = spark-streaming-flume-assembly, Version = %s. + Then, include the jar in the spark-submit command as + + $ bin/spark-submit --jars ... + +________________________________________________________________________________________________ + +""" % (sc.version, sc.version)) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 91ce681fbe169..188c8ff12067e 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -38,6 +38,7 @@ from pyspark.context import SparkConf, SparkContext, RDD from pyspark.streaming.context import StreamingContext from pyspark.streaming.kafka import Broker, KafkaUtils, OffsetRange, TopicAndPartition +from pyspark.streaming.flume import FlumeUtils class PySparkStreamingTestCase(unittest.TestCase): @@ -677,7 +678,156 @@ def test_kafka_rdd_with_leaders(self): rdd = KafkaUtils.createRDD(self.sc, kafkaParams, offsetRanges, leaders) self._validateRddResult(sendData, rdd) -if __name__ == "__main__": + +class FlumeStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + + def setUp(self): + super(FlumeStreamTests, self).setUp() + + utilsClz = self.ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.FlumeTestUtils") + self._utils = utilsClz.newInstance() + + def tearDown(self): + if self._utils is not None: + self._utils.close() + self._utils = None + + super(FlumeStreamTests, self).tearDown() + + def _startContext(self, n, compressed): + # Start the StreamingContext and also collect the result + dstream = FlumeUtils.createStream(self.ssc, "localhost", self._utils.getTestPort(), + enableDecompression=compressed) + result = [] + + def get_output(_, rdd): + for event in rdd.collect(): + if len(result) < n: + result.append(event) + dstream.foreachRDD(get_output) + self.ssc.start() + return result + + def _validateResult(self, input, result): + # Validate both the header and the body + header = {"test": "header"} + self.assertEqual(len(input), len(result)) + for i in range(0, len(input)): + self.assertEqual(header, result[i][0]) + self.assertEqual(input[i], result[i][1]) + + def _writeInput(self, input, compressed): + # Try to write input to the receiver until success or timeout + start_time = time.time() + while True: + try: + self._utils.writeInput(input, compressed) + break + except: + if time.time() - start_time < self.timeout: + time.sleep(0.01) + else: + raise + + def test_flume_stream(self): + input = [str(i) for i in range(1, 101)] + result = self._startContext(len(input), False) + self._writeInput(input, False) + self.wait_for(result, len(input)) + self._validateResult(input, result) + + def test_compressed_flume_stream(self): + input = [str(i) for i in range(1, 101)] + result = self._startContext(len(input), True) + self._writeInput(input, True) + self.wait_for(result, len(input)) + self._validateResult(input, result) + + +class FlumePollingStreamTests(PySparkStreamingTestCase): + timeout = 20 # seconds + duration = 1 + maxAttempts = 5 + + def setUp(self): + utilsClz = \ + self.sc._jvm.java.lang.Thread.currentThread().getContextClassLoader() \ + .loadClass("org.apache.spark.streaming.flume.PollingFlumeTestUtils") + self._utils = utilsClz.newInstance() + + def tearDown(self): + if self._utils is not None: + self._utils.close() + self._utils = None + + def _writeAndVerify(self, ports): + # Set up the streaming context and input streams + ssc = StreamingContext(self.sc, self.duration) + try: + addresses = [("localhost", port) for port in ports] + dstream = FlumeUtils.createPollingStream( + ssc, + addresses, + maxBatchSize=self._utils.eventsPerBatch(), + parallelism=5) + outputBuffer = [] + + def get_output(_, rdd): + for e in rdd.collect(): + outputBuffer.append(e) + + dstream.foreachRDD(get_output) + ssc.start() + self._utils.sendDatAndEnsureAllDataHasBeenReceived() + + self.wait_for(outputBuffer, self._utils.getTotalEvents()) + outputHeaders = [event[0] for event in outputBuffer] + outputBodies = [event[1] for event in outputBuffer] + self._utils.assertOutput(outputHeaders, outputBodies) + finally: + ssc.stop(False) + + def _testMultipleTimes(self, f): + attempt = 0 + while True: + try: + f() + break + except: + attempt += 1 + if attempt >= self.maxAttempts: + raise + else: + import traceback + traceback.print_exc() + + def _testFlumePolling(self): + try: + port = self._utils.startSingleSink() + self._writeAndVerify([port]) + self._utils.assertChannelsAreEmpty() + finally: + self._utils.close() + + def _testFlumePollingMultipleHosts(self): + try: + port = self._utils.startSingleSink() + self._writeAndVerify([port]) + self._utils.assertChannelsAreEmpty() + finally: + self._utils.close() + + def test_flume_polling(self): + self._testMultipleTimes(self._testFlumePolling) + + def test_flume_polling_multiple_hosts(self): + self._testMultipleTimes(self._testFlumePollingMultipleHosts) + + +def search_kafka_assembly_jar(): SPARK_HOME = os.environ["SPARK_HOME"] kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") jars = glob.glob( @@ -692,5 +842,30 @@ def test_kafka_rdd_with_leaders(self): raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " "remove all but one") % kafka_assembly_dir) else: - os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars[0] + return jars[0] + + +def search_flume_assembly_jar(): + SPARK_HOME = os.environ["SPARK_HOME"] + flume_assembly_dir = os.path.join(SPARK_HOME, "external/flume-assembly") + jars = glob.glob( + os.path.join(flume_assembly_dir, "target/scala-*/spark-streaming-flume-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming Flume assembly jar in %s. " % flume_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Flume assembly JARs in %s; please " + "remove all but one") % flume_assembly_dir) + else: + return jars[0] + +if __name__ == "__main__": + kafka_assembly_jar = search_kafka_assembly_jar() + flume_assembly_jar = search_flume_assembly_jar() + jars = "%s,%s" % (kafka_assembly_jar, flume_assembly_jar) + + os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars unittest.main() diff --git a/python/run-tests.py b/python/run-tests.py index b7737650daa54..7638854def2e8 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -31,6 +31,23 @@ import Queue else: import queue as Queue +if sys.version_info >= (2, 7): + subprocess_check_output = subprocess.check_output +else: + # SPARK-8763 + # backported from subprocess module in Python 2.7 + def subprocess_check_output(*popenargs, **kwargs): + if 'stdout' in kwargs: + raise ValueError('stdout argument not allowed, it will be overridden.') + process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) + output, unused_err = process.communicate() + retcode = process.poll() + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise subprocess.CalledProcessError(retcode, cmd, output=output) + return output # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -156,11 +173,11 @@ def main(): task_queue = Queue.Queue() for python_exec in python_execs: - python_implementation = subprocess.check_output( + python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation) - LOGGER.debug("%s version is: %s", python_exec, subprocess.check_output( + LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output( [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()) for module in modules_to_test: if python_implementation not in module.blacklisted_python_implementations: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 117c87a785fdb..15e84e68b9881 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -43,7 +43,7 @@ class Analyzer( registry: FunctionRegistry, conf: CatalystConf, maxIterations: Int = 100) - extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis { + extends RuleExecutor[LogicalPlan] with CheckAnalysis { def resolver: Resolver = { if (conf.caseSensitiveAnalysis) { @@ -76,7 +76,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: - typeCoercionRules ++ + HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a069b4710f38c..583338da57117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.types._ * Throws user facing errors when passed invalid queries that fail to analyze. */ trait CheckAnalysis { - self: Analyzer => /** * Override to provide additional checks for correct analysis. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e525ad623ff12..8420c54f7c335 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -22,7 +22,32 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ + +/** + * A collection of [[Rule Rules]] that can be used to coerce differing types that + * participate in operations into compatible ones. Most of these rules are based on Hive semantics, + * but they do not introduce any dependencies on the hive codebase. For this reason they remain in + * Catalyst until we have a more standard set of coercions. + */ object HiveTypeCoercion { + + val typeCoercionRules = + PropagateTypes :: + ConvertNaNs :: + InConversion :: + WidenTypes :: + PromoteStrings :: + DecimalPrecision :: + BooleanEquality :: + StringToIntegralCasts :: + FunctionArgumentConversion :: + CaseWhenCoercion :: + IfCoercion :: + Division :: + PropagateTypes :: + ImplicitTypeCasts :: + Nil + // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: private val numericPrecedence = @@ -79,7 +104,6 @@ object HiveTypeCoercion { }) } - /** * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. @@ -90,34 +114,6 @@ object HiveTypeCoercion { case Some(d) => findTightestCommonTypeOfTwo(d, c) }) } -} - -/** - * A collection of [[Rule Rules]] that can be used to coerce differing types that - * participate in operations into compatible ones. Most of these rules are based on Hive semantics, - * but they do not introduce any dependencies on the hive codebase. For this reason they remain in - * Catalyst until we have a more standard set of coercions. - */ -trait HiveTypeCoercion { - - import HiveTypeCoercion._ - - val typeCoercionRules = - PropagateTypes :: - ConvertNaNs :: - InConversion :: - WidenTypes :: - PromoteStrings :: - DecimalPrecision :: - BooleanEquality :: - StringToIntegralCasts :: - FunctionArgumentConversion :: - CaseWhenCoercion :: - IfCoercion :: - Division :: - PropagateTypes :: - AddCastForAutoCastInputTypes :: - Nil /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to @@ -154,6 +150,7 @@ trait HiveTypeCoercion { * Converts string "NaN"s that are in binary operators with a NaN-able types (Float / Double) to * the appropriate numeric equivalent. */ + // TODO: remove this rule and make Cast handle Nan. object ConvertNaNs extends Rule[LogicalPlan] { private val StringNaN = Literal("NaN") @@ -163,19 +160,19 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e /* Double Conversions */ - case b @ BinaryExpression(StringNaN, right @ DoubleType()) => + case b @ BinaryOperator(StringNaN, right @ DoubleType()) => b.makeCopy(Array(Literal(Double.NaN), right)) - case b @ BinaryExpression(left @ DoubleType(), StringNaN) => + case b @ BinaryOperator(left @ DoubleType(), StringNaN) => b.makeCopy(Array(left, Literal(Double.NaN))) /* Float Conversions */ - case b @ BinaryExpression(StringNaN, right @ FloatType()) => + case b @ BinaryOperator(StringNaN, right @ FloatType()) => b.makeCopy(Array(Literal(Float.NaN), right)) - case b @ BinaryExpression(left @ FloatType(), StringNaN) => + case b @ BinaryOperator(left @ FloatType(), StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) /* Use float NaN by default to avoid unnecessary type widening */ - case b @ BinaryExpression(left @ StringNaN, StringNaN) => + case b @ BinaryOperator(left @ StringNaN, StringNaN) => b.makeCopy(Array(left, Literal(Float.NaN))) } } @@ -202,8 +199,6 @@ trait HiveTypeCoercion { * - LongType to DoubleType */ object WidenTypes extends Rule[LogicalPlan] { - import HiveTypeCoercion._ - def apply(plan: LogicalPlan): LogicalPlan = plan transform { // TODO: unions with fixed-precision decimals case u @ Union(left, right) if u.childrenResolved && !u.resolved => @@ -251,12 +246,12 @@ trait HiveTypeCoercion { Union(newLeft, newRight) - // Also widen types for BinaryExpressions. + // Also widen types for BinaryOperator. case q: LogicalPlan => q transformExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType => val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) val newRight = if (right.dataType == widestType) right else Cast(right, widestType) @@ -484,7 +479,7 @@ trait HiveTypeCoercion { // Promote integers inside a binary expression with fixed-precision decimals to decimals, // and fixed-precision decimals in an expression with floats / doubles to doubles - case b @ BinaryExpression(left, right) if left.dataType != right.dataType => + case b @ BinaryOperator(left, right) if left.dataType != right.dataType => (left.dataType, right.dataType) match { case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right)) @@ -655,8 +650,6 @@ trait HiveTypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. */ object CaseWhenCoercion extends Rule[LogicalPlan] { - import HiveTypeCoercion._ - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") @@ -713,14 +706,13 @@ trait HiveTypeCoercion { * Casts types according to the expected input types for Expressions that have the trait * [[AutoCastInputTypes]]. */ - object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] { - + object ImplicitTypeCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => - val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { + case e: AutoCastInputTypes if e.children.map(_.dataType) != e.inputTypes => + val newC = (e.children, e.children.map(_.dataType), e.inputTypes).zipped.map { case (child, actual, expected) => if (actual == expected) child else Cast(child, expected) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala new file mode 100644 index 0000000000000..450fc4165f93b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.types.DataType + + +/** + * An trait that gets mixin to define the expected input types of an expression. + */ +trait ExpectsInputTypes { self: Expression => + + /** + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. + * + * The possible values at each position are: + * 1. a specific data type, e.g. LongType, StringType. + * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. + * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + */ + def inputTypes: Seq[Any] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will do the type checking in `HiveTypeCoercion`, so always returning success here. + TypeCheckResult.TypeCheckSuccess + } +} + +/** + * Expressions that require a specific `DataType` as input should implement this trait + * so that the proper type conversions can be performed in the analyzer. + */ +trait AutoCastInputTypes { self: Expression => + + def inputTypes: Seq[DataType] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, + // so type mismatch error won't be reported here, but for underling `Cast`s. + TypeCheckResult.TypeCheckSuccess + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b5063f32fa529..cafbbafdca207 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -119,17 +119,6 @@ abstract class Expression extends TreeNode[Expression] { */ def childrenResolved: Boolean = children.forall(_.resolved) - /** - * Returns a string representation of this expression that does not have developer centric - * debugging information like the expression id. - */ - def prettyString: String = { - transform { - case a: AttributeReference => PrettyAttribute(a.name) - case u: UnresolvedAttribute => PrettyAttribute(u.name) - }.toString - } - /** * Returns true when two expressions will always compute the same result, even if they differ * cosmetically (i.e. capitalization of names in attributes may be different). @@ -154,71 +143,40 @@ abstract class Expression extends TreeNode[Expression] { * Note: it's not valid to call this method until `childrenResolved == true`. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess -} - -abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { - self: Product => - - def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol") - - override def foldable: Boolean = left.foldable && right.foldable - - override def nullable: Boolean = left.nullable || right.nullable - - override def toString: String = s"($left $symbol $right)" /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. - * - * @param f accepts two variable names and returns Java code to compute the output. + * Returns a user-facing string representation of this expression's name. + * This should usually match the name of the function in SQL. */ - protected def defineCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String) => String): String = { - nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { - s"$result = ${f(eval1, eval2)};" - }) - } + def prettyName: String = getClass.getSimpleName.toLowerCase /** - * Short hand for generating binary evaluation code. - * If either of the sub-expressions is null, the result of this computation - * is assumed to be null. + * Returns a user-facing string representation of this expression, i.e. does not have developer + * centric debugging information like the expression id. */ - protected def nullSafeCodeGen( - ctx: CodeGenContext, - ev: GeneratedExpressionCode, - f: (String, String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } - """ + def prettyString: String = { + transform { + case a: AttributeReference => PrettyAttribute(a.name) + case u: UnresolvedAttribute => PrettyAttribute(u.name) + }.toString } -} -private[sql] object BinaryExpression { - def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right)) + override def toString: String = prettyName + children.mkString("(", ",", ")") } + +/** + * A leaf expression, i.e. one without any child expressions. + */ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] { self: Product => } + +/** + * An expression with one input and one output. The output is by default evaluated to null + * if the input is evaluated to null. + */ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => @@ -265,18 +223,76 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } + /** - * Expressions that require a specific `DataType` as input should implement this trait - * so that the proper type conversions can be performed in the analyzer. + * An expression with two inputs and one output. The output is by default evaluated to null + * if any input is evaluated to null. */ -trait AutoCastInputTypes { - self: Expression => +abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { + self: Product => + + override def foldable: Boolean = left.foldable && right.foldable - def expectedChildTypes: Seq[DataType] + override def nullable: Boolean = left.nullable || right.nullable + + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts two variable names and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } - override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`, - // so type mismatch error won't be reported here, but for underling `Cast`s. - TypeCheckResult.TypeCheckSuccess + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } + } + """ } } + + +/** + * An expression that has two inputs that are expected to the be same type. If the two inputs have + * different types, the analyzer will find the tightest common type and do the proper type casting. + */ +abstract class BinaryOperator extends BinaryExpression { + self: Product => + + def symbol: String + + override def toString: String = s"($left $symbol $right)" +} + + +private[sql] object BinaryOperator { + def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index ebabb6f117851..caf021b016a41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -29,7 +29,7 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi override def nullable: Boolean = true - override def toString: String = s"scalaUDF(${children.mkString(",")})" + override def toString: String = s"UDF(${children.mkString(",")})" // scalastyle:off diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index a9fc54c548f49..64e07bd2a17db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -94,7 +94,6 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"MIN($child)" override def asPartial: SplitEvaluation = { val partialMin = Alias(Min(child), "PartialMin")() @@ -128,7 +127,6 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def nullable: Boolean = true override def dataType: DataType = child.dataType - override def toString: String = s"MAX($child)" override def asPartial: SplitEvaluation = { val partialMax = Alias(Max(child), "PartialMax")() @@ -162,7 +160,6 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def nullable: Boolean = false override def dataType: LongType.type = LongType - override def toString: String = s"COUNT($child)" override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() @@ -390,6 +387,8 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def prettyName: String = "avg" + override def nullable: Boolean = true override def dataType: DataType = child.dataType match { @@ -401,8 +400,6 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } - override def toString: String = s"AVG($child)" - override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) | DecimalType.Unlimited => @@ -494,8 +491,6 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ child.dataType } - override def toString: String = s"SUM($child)" - override def asPartial: SplitEvaluation = { child.dataType match { case DecimalType.Fixed(_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 5363b3556886a..4fbf4c87009c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -57,7 +57,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic { } case class UnaryPositive(child: Expression) extends UnaryArithmetic { - override def toString: String = s"positive($child)" + override def prettyName: String = "positive" override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = defineCodeGen(ctx, ev, c => c) @@ -69,8 +69,6 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic { * A function that get the absolute value of the numeric value. */ case class Abs(child: Expression) extends UnaryArithmetic { - override def toString: String = s"Abs($child)" - override def checkInputDataTypes(): TypeCheckResult = TypeUtils.checkForNumericExpr(child.dataType, "function abs") @@ -79,10 +77,9 @@ case class Abs(child: Expression) extends UnaryArithmetic { protected override def evalInternal(evalE: Any) = numeric.abs(evalE) } -abstract class BinaryArithmetic extends BinaryExpression { +abstract class BinaryArithmetic extends BinaryOperator { self: Product => - override def dataType: DataType = left.dataType override def checkInputDataTypes(): TypeCheckResult = { @@ -360,7 +357,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic { } """ } - override def toString: String = s"MaxOf($left, $right)" + + override def symbol: String = "max" + override def prettyName: String = symbol } case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { @@ -413,5 +412,6 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic { """ } - override def toString: String = s"MinOf($left, $right)" + override def symbol: String = "min" + override def prettyName: String = symbol } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 5be47175fa7f1..3c7ee9cc16599 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -148,7 +148,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { }.mkString("\n") val copyColumns = expressions.zipWithIndex.map { case (e, i) => - s"""arr[$i] = c$i;""" + s"""if (!nullBits[$i]) arr[$i] = c$i;""" }.mkString("\n ") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 5def57b067424..67e7dc4ec8b14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -43,7 +43,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { children.map(_.eval(input)) } - override def toString: String = s"Array(${children.mkString(",")})" + override def prettyName: String = "array" } /** @@ -71,4 +71,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = { InternalRow(children.map(_.eval(input)): _*) } + + override def prettyName: String = "struct" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 1097fea776e17..5725405451435 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -59,7 +59,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) extends UnaryExpression with Serializable with AutoCastInputTypes { self: Product => - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) + override def inputTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType override def nullable: Boolean = true override def toString: String = s"$name($child)" @@ -98,7 +98,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product => - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) override def toString: String = s"$name($left, $right)" @@ -210,7 +210,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia case class Bin(child: Expression) extends UnaryExpression with Serializable with AutoCastInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(LongType) + override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a7bcbe46c339a..407023e472081 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -36,7 +36,7 @@ case class Md5(child: Expression) override def dataType: DataType = StringType - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + override def inputTypes: Seq[DataType] = Seq(BinaryType) override def eval(input: InternalRow): Any = { val value = child.eval(input) @@ -68,7 +68,7 @@ case class Sha2(left: Expression, right: Expression) override def toString: String = s"SHA2($left, $right)" - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType) + override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) @@ -151,7 +151,7 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp override def dataType: DataType = StringType - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + override def inputTypes: Seq[DataType] = Seq(BinaryType) override def eval(input: InternalRow): Any = { val value = child.eval(input) @@ -179,7 +179,7 @@ case class Crc32(child: Expression) override def dataType: DataType = LongType - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + override def inputTypes: Seq[DataType] = Seq(BinaryType) override def eval(input: InternalRow): Any = { val value = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 78be2824347d7..145d323a9f0bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -38,8 +38,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } - override def toString: String = s"Coalesce(${children.mkString(",")})" - override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 98cd5aa8148c4..34df89a163895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -72,7 +72,7 @@ trait PredicateHelper { case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { override def toString: String = s"NOT $child" - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType) + override def inputTypes: Seq[DataType] = Seq(BooleanType) override def eval(input: InternalRow): Any = { child.eval(input) match { @@ -120,9 +120,9 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with AutoCastInputTypes { + extends BinaryOperator with Predicate with AutoCastInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def symbol: String = "&&" @@ -169,9 +169,9 @@ case class And(left: Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with AutoCastInputTypes { + extends BinaryOperator with Predicate with AutoCastInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def symbol: String = "||" @@ -217,7 +217,7 @@ case class Or(left: Expression, right: Expression) } } -abstract class BinaryComparison extends BinaryExpression with Predicate { +abstract class BinaryComparison extends BinaryOperator with Predicate { self: Product => override def checkInputDataTypes(): TypeCheckResult = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index daa9f4403ffab..5d51a4ca65332 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -137,8 +137,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres override def dataType: DataType = left.dataType - override def symbol: String = "++=" - override def eval(input: InternalRow): Any = { val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]] if(leftEval != null) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index ce184e4f32f18..b020f2bbc5818 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -32,7 +32,7 @@ trait StringRegexExpression extends AutoCastInputTypes { override def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = BooleanType - override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) // try cache the pattern for Literal private lazy val cache: Pattern = right match { @@ -75,8 +75,6 @@ trait StringRegexExpression extends AutoCastInputTypes { case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "LIKE" - // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character override def escape(v: String): String = @@ -101,14 +99,16 @@ case class Like(left: Expression, right: Expression) } override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() + + override def toString: String = s"$left LIKE $right" } case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - override def symbol: String = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) + override def toString: String = s"$left RLIKE $right" } trait CaseConversionExpression extends AutoCastInputTypes { @@ -117,7 +117,7 @@ trait CaseConversionExpression extends AutoCastInputTypes { def convert(v: UTF8String): UTF8String override def dataType: DataType = StringType - override def expectedChildTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType) override def eval(input: InternalRow): Any = { val evaluated = child.eval(input) @@ -134,9 +134,7 @@ trait CaseConversionExpression extends AutoCastInputTypes { */ case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toUpperCase() - - override def toString: String = s"Upper($child)" + override def convert(v: UTF8String): UTF8String = v.toUpperCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") @@ -148,9 +146,7 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE */ case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression { - override def convert(v: UTF8String): UTF8String = v.toLowerCase() - - override def toString: String = s"Lower($child)" + override def convert(v: UTF8String): UTF8String = v.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") @@ -165,7 +161,7 @@ trait StringComparison extends AutoCastInputTypes { override def nullable: Boolean = left.nullable || right.nullable - override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def eval(input: InternalRow): Any = { val leftEval = left.eval(input) @@ -178,8 +174,6 @@ trait StringComparison extends AutoCastInputTypes { } } - override def symbol: String = nodeName - override def toString: String = s"$nodeName($left, $right)" } @@ -238,7 +232,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) if (str.dataType == BinaryType) str.dataType else StringType } - override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) override def children: Seq[Expression] = str :: pos :: len :: Nil @@ -284,12 +278,6 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } } - - override def toString: String = len match { - // TODO: This is broken because max is not an integer value. - case max if max == Integer.MAX_VALUE => s"SUBSTR($str, $pos)" - case _ => s"SUBSTR($str, $pos, $len)" - } } /** @@ -297,16 +285,16 @@ case class Substring(str: Expression, pos: Expression, len: Expression) */ case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes { override def dataType: DataType = IntegerType - override def expectedChildTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType) override def eval(input: InternalRow): Any = { val string = child.eval(input) if (string == null) null else string.asInstanceOf[UTF8String].length } - override def toString: String = s"length($child)" - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"($c).length()") } + + override def prettyName: String = "length" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index b009a200b920f..e911b907e8536 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -161,7 +161,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { if (tmp.nonEmpty) throw e inBacktick = true } else if (char == '.') { - if (tmp.isEmpty) throw e + if (name(i - 1) == '.' || i == name.length - 1) throw e nameParts += tmp.mkString tmp.clear() } else { @@ -170,7 +170,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { } i += 1 } - if (tmp.isEmpty || inBacktick) throw e + if (inBacktick) throw e nameParts += tmp.mkString nameParts.toSeq } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index f7b8e21bed490..eae3666595a38 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -113,8 +113,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("coalesce casts") { - val fac = new HiveTypeCoercion { }.FunctionArgumentConversion - ruleTest(fac, + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -123,7 +122,7 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(fac, + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -135,7 +134,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("type coercion for If") { - val rule = new HiveTypeCoercion { }.IfCoercion + val rule = HiveTypeCoercion.IfCoercion ruleTest(rule, If(Literal(true), Literal(1), Literal(1L)), If(Literal(true), Cast(Literal(1), LongType), Literal(1L)) @@ -148,19 +147,18 @@ class HiveTypeCoercionSuite extends PlanTest { } test("type coercion for CaseKeyWhen") { - val cwc = new HiveTypeCoercion {}.CaseWhenCoercion - ruleTest(cwc, + ruleTest(HiveTypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - ruleTest(cwc, + ruleTest(HiveTypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) } test("type coercion simplification for equal to") { - val be = new HiveTypeCoercion {}.BooleanEquality + val be = HiveTypeCoercion.BooleanEquality ruleTest(be, EqualTo(Literal(true), Literal(1)), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 7d95ef7f710af..3171caf6ad77f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -136,6 +136,9 @@ trait ExpressionEvalHelper { val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } + if (actual.copy() != expectedRow) { + fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") + } } protected def checkEvaluationWithOptimization( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index bda217935cb05..86792f0217572 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -73,7 +73,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("+", "1", "*", "2", "-", "3", "4") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformDown { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -85,7 +85,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression transformUp { - case b: BinaryExpression => actual.append(b.symbol); b + case b: BinaryOperator => actual.append(b.symbol); b case l: Literal => actual.append(l.toString); l } @@ -125,7 +125,7 @@ class TreeNodeSuite extends SparkFunSuite { val expected = Seq("1", "2", "3", "4", "-", "*", "+") val expression = Add(Literal(1), Multiply(Literal(2), Subtract(Literal(3), Literal(4)))) expression foreachUp { - case b: BinaryExpression => actual.append(b.symbol); + case b: BinaryOperator => actual.append(b.symbol); case l: Literal => actual.append(l.toString); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 8fe1f7e34cb5e..caad2da80b1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1469,14 +1469,12 @@ class DataFrame private[sql]( lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema - internalRowRdd.mapPartitions { rows => + queryExecution.toRdd.mapPartitions { rows => val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]) } } - private[sql] def internalRowRdd = queryExecution.executedPlan.execute() - /** * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. * @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 3ebbf96090a55..4e2e2c210d5a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -90,7 +90,7 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) } - val freqItems = df.select(cols.map(Column(_)) : _*).internalRowRdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index b624ef7e8fa1a..23ddfa9839e5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -82,7 +82,7 @@ private[sql] object StatFunctions extends Logging { s"with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) - df.select(columns: _*).internalRowRdd.aggregate(new CovarianceCounter)( + df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)( seqOp = (counter, row) => { counter.add(row.getDouble(0), row.getDouble(1)) }, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 19274cf544a11..feb3905ee016b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1867,7 +1867,15 @@ object functions { */ @deprecated("Use callUDF", "1.5.0") def callUdf(udfName: String, cols: Column*): Column = { - UnresolvedFunction(udfName, cols.map(_.expr)) + // Note: we avoid using closures here because on file systems that are case-insensitive, the + // compiled class file for the closure here will conflict with the one in callUDF (upper case). + val exprs = new Array[Expression](cols.size) + var i = 0 + while (i < cols.size) { + exprs(i) = cols(i).expr + i += 1 + } + UnresolvedFunction(udfName, exprs) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 42b51caab5ce9..7214eb0b4169a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -154,7 +154,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) + df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => @@ -220,7 +220,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) + df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 50d324c0686fa..afb1cf5f8d1cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -730,4 +730,11 @@ class DataFrameSuite extends QueryTest { val res11 = ctx.range(-1).select("id") assert(res11.count == 0) } + + test("SPARK-8621: support empty string column name") { + val df = Seq(Tuple1(1)).toDF("").as("t") + // We should allow empty string as column name + df.col("") + df.col("t.``") + } }