From b6974f8fed1726a381636e996834111a8e7ced8d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 5 Nov 2015 15:34:05 -0800 Subject: [PATCH 01/88] [SPARK-11536][SQL] Remove the internal implicit conversion from Expression to Column in functions.scala Author: Reynold Xin Closes #9505 from rxin/SPARK-11536. --- .../org/apache/spark/sql/functions.scala | 580 +++++++++--------- 1 file changed, 299 insertions(+), 281 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c70c965a9b04c..04627589886a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -51,7 +51,7 @@ import org.apache.spark.util.Utils object functions { // scalastyle:on - private[this] implicit def toColumn(expr: Expression): Column = Column(expr) + private def withExpr(expr: Expression): Column = Column(expr) /** * Returns a [[Column]] based on the given column name. @@ -128,7 +128,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr) + def approxCountDistinct(e: Column): Column = withExpr { ApproxCountDistinct(e.expr) } /** * Aggregate function: returns the approximate number of distinct items in a group. @@ -144,7 +144,9 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd) + def approxCountDistinct(e: Column, rsd: Double): Column = withExpr { + ApproxCountDistinct(e.expr, rsd) + } /** * Aggregate function: returns the approximate number of distinct items in a group. @@ -162,7 +164,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def avg(e: Column): Column = Average(e.expr) + def avg(e: Column): Column = withExpr { Average(e.expr) } /** * Aggregate function: returns the average of the values in a group. @@ -178,8 +180,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def corr(column1: Column, column2: Column): Column = + def corr(column1: Column, column2: Column): Column = withExpr { Corr(column1.expr, column2.expr) + } /** * Aggregate function: returns the Pearson Correlation Coefficient for two columns. @@ -187,8 +190,9 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def corr(columnName1: String, columnName2: String): Column = + def corr(columnName1: String, columnName2: String): Column = { corr(Column(columnName1), Column(columnName2)) + } /** * Aggregate function: returns the number of items in a group. @@ -196,10 +200,12 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def count(e: Column): Column = e.expr match { - // Turn count(*) into count(1) - case s: Star => Count(Literal(1)) - case _ => Count(e.expr) + def count(e: Column): Column = withExpr { + e.expr match { + // Turn count(*) into count(1) + case s: Star => Count(Literal(1)) + case _ => Count(e.expr) + } } /** @@ -217,8 +223,9 @@ object functions { * @since 1.3.0 */ @scala.annotation.varargs - def countDistinct(expr: Column, exprs: Column*): Column = + def countDistinct(expr: Column, exprs: Column*): Column = withExpr { CountDistinct((expr +: exprs).map(_.expr)) + } /** * Aggregate function: returns the number of distinct items in a group. @@ -236,7 +243,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def first(e: Column): Column = First(e.expr) + def first(e: Column): Column = withExpr { First(e.expr) } /** * Aggregate function: returns the first value of a column in a group. @@ -252,7 +259,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def kurtosis(e: Column): Column = Kurtosis(e.expr) + def kurtosis(e: Column): Column = withExpr { Kurtosis(e.expr) } /** * Aggregate function: returns the last value in a group. @@ -260,7 +267,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def last(e: Column): Column = Last(e.expr) + def last(e: Column): Column = withExpr { Last(e.expr) } /** * Aggregate function: returns the last value of the column in a group. @@ -276,7 +283,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def max(e: Column): Column = Max(e.expr) + def max(e: Column): Column = withExpr { Max(e.expr) } /** * Aggregate function: returns the maximum value of the column in a group. @@ -310,7 +317,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def min(e: Column): Column = Min(e.expr) + def min(e: Column): Column = withExpr { Min(e.expr) } /** * Aggregate function: returns the minimum value of the column in a group. @@ -326,7 +333,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def skewness(e: Column): Column = Skewness(e.expr) + def skewness(e: Column): Column = withExpr { Skewness(e.expr) } /** * Aggregate function: alias for [[stddev_samp]]. @@ -334,7 +341,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = StddevSamp(e.expr) + def stddev(e: Column): Column = withExpr { StddevSamp(e.expr) } /** * Aggregate function: returns the unbiased sample standard deviation of @@ -343,7 +350,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev_samp(e: Column): Column = StddevSamp(e.expr) + def stddev_samp(e: Column): Column = withExpr { StddevSamp(e.expr) } /** * Aggregate function: returns the population standard deviation of @@ -352,7 +359,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev_pop(e: Column): Column = StddevPop(e.expr) + def stddev_pop(e: Column): Column = withExpr { StddevPop(e.expr) } /** * Aggregate function: returns the sum of all values in the expression. @@ -360,7 +367,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def sum(e: Column): Column = Sum(e.expr) + def sum(e: Column): Column = withExpr { Sum(e.expr) } /** * Aggregate function: returns the sum of all values in the given column. @@ -376,7 +383,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def sumDistinct(e: Column): Column = SumDistinct(e.expr) + def sumDistinct(e: Column): Column = withExpr { SumDistinct(e.expr) } /** * Aggregate function: returns the sum of distinct values in the expression. @@ -392,7 +399,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def variance(e: Column): Column = VarianceSamp(e.expr) + def variance(e: Column): Column = withExpr { VarianceSamp(e.expr) } /** * Aggregate function: returns the unbiased variance of the values in a group. @@ -400,7 +407,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def var_samp(e: Column): Column = VarianceSamp(e.expr) + def var_samp(e: Column): Column = withExpr { VarianceSamp(e.expr) } /** * Aggregate function: returns the population variance of the values in a group. @@ -408,7 +415,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def var_pop(e: Column): Column = VariancePop(e.expr) + def var_pop(e: Column): Column = withExpr { VariancePop(e.expr) } ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions @@ -429,9 +436,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def cumeDist(): Column = { - UnresolvedWindowFunction("cume_dist", Nil) - } + def cumeDist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) } /** * Window function: returns the rank of rows within a window partition, without any gaps. @@ -446,9 +451,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def denseRank(): Column = { - UnresolvedWindowFunction("dense_rank", Nil) - } + def denseRank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) } /** * Window function: returns the value that is `offset` rows before the current row, and @@ -460,9 +463,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(e: Column, offset: Int): Column = { - lag(e, offset, null) - } + def lag(e: Column, offset: Int): Column = lag(e, offset, null) /** * Window function: returns the value that is `offset` rows before the current row, and @@ -474,9 +475,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(columnName: String, offset: Int): Column = { - lag(columnName, offset, null) - } + def lag(columnName: String, offset: Int): Column = lag(columnName, offset, null) /** * Window function: returns the value that is `offset` rows before the current row, and @@ -502,7 +501,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lag(e: Column, offset: Int, defaultValue: Any): Column = { + def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr { UnresolvedWindowFunction("lag", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) } @@ -516,9 +515,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(columnName: String, offset: Int): Column = { - lead(columnName, offset, null) - } + def lead(columnName: String, offset: Int): Column = { lead(columnName, offset, null) } /** * Window function: returns the value that is `offset` rows after the current row, and @@ -530,9 +527,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(e: Column, offset: Int): Column = { - lead(e, offset, null) - } + def lead(e: Column, offset: Int): Column = { lead(e, offset, null) } /** * Window function: returns the value that is `offset` rows after the current row, and @@ -558,7 +553,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def lead(e: Column, offset: Int, defaultValue: Any): Column = { + def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr { UnresolvedWindowFunction("lead", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) } @@ -572,9 +567,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def ntile(n: Int): Column = { - UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) - } + def ntile(n: Int): Column = withExpr { UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) } /** * Window function: returns the relative rank (i.e. percentile) of rows within a window partition. @@ -589,9 +582,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def percentRank(): Column = { - UnresolvedWindowFunction("percent_rank", Nil) - } + def percentRank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) } /** * Window function: returns the rank of rows within a window partition. @@ -606,9 +597,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def rank(): Column = { - UnresolvedWindowFunction("rank", Nil) - } + def rank(): Column = withExpr { UnresolvedWindowFunction("rank", Nil) } /** * Window function: returns a sequential number starting at 1 within a window partition. @@ -618,9 +607,7 @@ object functions { * @group window_funcs * @since 1.4.0 */ - def rowNumber(): Column = { - UnresolvedWindowFunction("row_number", Nil) - } + def rowNumber(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) } ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions @@ -632,7 +619,7 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def abs(e: Column): Column = Abs(e.expr) + def abs(e: Column): Column = withExpr { Abs(e.expr) } /** * Creates a new array column. The input columns must all have the same data type. @@ -641,7 +628,7 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def array(cols: Column*): Column = CreateArray(cols.map(_.expr)) + def array(cols: Column*): Column = withExpr { CreateArray(cols.map(_.expr)) } /** * Creates a new array column. The input columns must all have the same data type. @@ -679,14 +666,14 @@ object functions { * @since 1.3.0 */ @scala.annotation.varargs - def coalesce(e: Column*): Column = Coalesce(e.map(_.expr)) + def coalesce(e: Column*): Column = withExpr { Coalesce(e.map(_.expr)) } /** * Creates a string column for the file name of the current Spark task. * * @group normal_funcs */ - def inputFileName(): Column = InputFileName() + def inputFileName(): Column = withExpr { InputFileName() } /** * Return true iff the column is NaN. @@ -694,7 +681,7 @@ object functions { * @group normal_funcs * @since 1.5.0 */ - def isNaN(e: Column): Column = IsNaN(e.expr) + def isNaN(e: Column): Column = withExpr { IsNaN(e.expr) } /** * A column expression that generates monotonically increasing 64-bit integers. @@ -711,7 +698,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def monotonicallyIncreasingId(): Column = MonotonicallyIncreasingID() + def monotonicallyIncreasingId(): Column = withExpr { MonotonicallyIncreasingID() } /** * Returns col1 if it is not NaN, or col2 if col1 is NaN. @@ -721,7 +708,7 @@ object functions { * @group normal_funcs * @since 1.5.0 */ - def nanvl(col1: Column, col2: Column): Column = NaNvl(col1.expr, col2.expr) + def nanvl(col1: Column, col2: Column): Column = withExpr { NaNvl(col1.expr, col2.expr) } /** * Unary minus, i.e. negate the expression. @@ -760,7 +747,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def rand(seed: Long): Column = Rand(seed) + def rand(seed: Long): Column = withExpr { Rand(seed) } /** * Generate a random column with i.i.d. samples from U[0.0, 1.0]. @@ -776,7 +763,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def randn(seed: Long): Column = Randn(seed) + def randn(seed: Long): Column = withExpr { Randn(seed) } /** * Generate a column with i.i.d. samples from the standard normal distribution. @@ -794,7 +781,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def sparkPartitionId(): Column = SparkPartitionID() + def sparkPartitionId(): Column = withExpr { SparkPartitionID() } /** * Computes the square root of the specified float value. @@ -802,7 +789,7 @@ object functions { * @group math_funcs * @since 1.3.0 */ - def sqrt(e: Column): Column = Sqrt(e.expr) + def sqrt(e: Column): Column = withExpr { Sqrt(e.expr) } /** * Computes the square root of the specified float value. @@ -823,9 +810,7 @@ object functions { * @since 1.4.0 */ @scala.annotation.varargs - def struct(cols: Column*): Column = { - CreateStruct(cols.map(_.expr)) - } + def struct(cols: Column*): Column = withExpr { CreateStruct(cols.map(_.expr)) } /** * Creates a new struct column that composes multiple input columns. @@ -858,7 +843,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def when(condition: Column, value: Any): Column = { + def when(condition: Column, value: Any): Column = withExpr { CaseWhen(Seq(condition.expr, lit(value).expr)) } @@ -868,7 +853,7 @@ object functions { * @group normal_funcs * @since 1.4.0 */ - def bitwiseNOT(e: Column): Column = BitwiseNot(e.expr) + def bitwiseNOT(e: Column): Column = withExpr { BitwiseNot(e.expr) } /** * Parses the expression string into the column that it represents, similar to @@ -893,7 +878,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def acos(e: Column): Column = Acos(e.expr) + def acos(e: Column): Column = withExpr { Acos(e.expr) } /** * Computes the cosine inverse of the given column; the returned angle is in the range @@ -911,7 +896,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def asin(e: Column): Column = Asin(e.expr) + def asin(e: Column): Column = withExpr { Asin(e.expr) } /** * Computes the sine inverse of the given column; the returned angle is in the range @@ -928,7 +913,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan(e: Column): Column = Atan(e.expr) + def atan(e: Column): Column = withExpr { Atan(e.expr) } /** * Computes the tangent inverse of the given column. @@ -945,7 +930,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Column): Column = Atan2(l.expr, r.expr) + def atan2(l: Column, r: Column): Column = withExpr { Atan2(l.expr, r.expr) } /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to @@ -982,7 +967,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(l: Column, r: Double): Column = atan2(l, lit(r).expr) + def atan2(l: Column, r: Double): Column = atan2(l, lit(r)) /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to @@ -1000,7 +985,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def atan2(l: Double, r: Column): Column = atan2(lit(l).expr, r) + def atan2(l: Double, r: Column): Column = atan2(lit(l), r) /** * Returns the angle theta from the conversion of rectangular coordinates (x, y) to @@ -1018,7 +1003,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def bin(e: Column): Column = Bin(e.expr) + def bin(e: Column): Column = withExpr { Bin(e.expr) } /** * An expression that returns the string representation of the binary value of the given long @@ -1035,7 +1020,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cbrt(e: Column): Column = Cbrt(e.expr) + def cbrt(e: Column): Column = withExpr { Cbrt(e.expr) } /** * Computes the cube-root of the given column. @@ -1051,7 +1036,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def ceil(e: Column): Column = Ceil(e.expr) + def ceil(e: Column): Column = withExpr { Ceil(e.expr) } /** * Computes the ceiling of the given column. @@ -1067,8 +1052,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def conv(num: Column, fromBase: Int, toBase: Int): Column = + def conv(num: Column, fromBase: Int, toBase: Int): Column = withExpr { Conv(num.expr, lit(fromBase).expr, lit(toBase).expr) + } /** * Computes the cosine of the given value. @@ -1076,7 +1062,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cos(e: Column): Column = Cos(e.expr) + def cos(e: Column): Column = withExpr { Cos(e.expr) } /** * Computes the cosine of the given column. @@ -1092,7 +1078,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def cosh(e: Column): Column = Cosh(e.expr) + def cosh(e: Column): Column = withExpr { Cosh(e.expr) } /** * Computes the hyperbolic cosine of the given column. @@ -1108,7 +1094,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def exp(e: Column): Column = Exp(e.expr) + def exp(e: Column): Column = withExpr { Exp(e.expr) } /** * Computes the exponential of the given column. @@ -1124,7 +1110,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def expm1(e: Column): Column = Expm1(e.expr) + def expm1(e: Column): Column = withExpr { Expm1(e.expr) } /** * Computes the exponential of the given column. @@ -1140,7 +1126,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def factorial(e: Column): Column = Factorial(e.expr) + def factorial(e: Column): Column = withExpr { Factorial(e.expr) } /** * Computes the floor of the given value. @@ -1148,7 +1134,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def floor(e: Column): Column = Floor(e.expr) + def floor(e: Column): Column = withExpr { Floor(e.expr) } /** * Computes the floor of the given column. @@ -1166,7 +1152,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(exprs: Column*): Column = { + def greatest(exprs: Column*): Column = withExpr { require(exprs.length > 1, "greatest requires at least 2 arguments.") Greatest(exprs.map(_.expr)) } @@ -1189,7 +1175,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def hex(column: Column): Column = Hex(column.expr) + def hex(column: Column): Column = withExpr { Hex(column.expr) } /** * Inverse of hex. Interprets each pair of characters as a hexadecimal number @@ -1198,7 +1184,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def unhex(column: Column): Column = Unhex(column.expr) + def unhex(column: Column): Column = withExpr { Unhex(column.expr) } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1206,7 +1192,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Column, r: Column): Column = Hypot(l.expr, r.expr) + def hypot(l: Column, r: Column): Column = withExpr { Hypot(l.expr, r.expr) } /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1239,7 +1225,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Column, r: Double): Column = hypot(l, lit(r).expr) + def hypot(l: Column, r: Double): Column = hypot(l, lit(r)) /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1255,7 +1241,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def hypot(l: Double, r: Column): Column = hypot(lit(l).expr, r) + def hypot(l: Double, r: Column): Column = hypot(lit(l), r) /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. @@ -1273,7 +1259,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(exprs: Column*): Column = { + def least(exprs: Column*): Column = withExpr { require(exprs.length > 1, "least requires at least 2 arguments.") Least(exprs.map(_.expr)) } @@ -1296,7 +1282,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log(e: Column): Column = Log(e.expr) + def log(e: Column): Column = withExpr { Log(e.expr) } /** * Computes the natural logarithm of the given column. @@ -1312,7 +1298,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log(base: Double, a: Column): Column = Logarithm(lit(base).expr, a.expr) + def log(base: Double, a: Column): Column = withExpr { Logarithm(lit(base).expr, a.expr) } /** * Returns the first argument-base logarithm of the second argument. @@ -1328,7 +1314,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log10(e: Column): Column = Log10(e.expr) + def log10(e: Column): Column = withExpr { Log10(e.expr) } /** * Computes the logarithm of the given value in base 10. @@ -1344,7 +1330,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def log1p(e: Column): Column = Log1p(e.expr) + def log1p(e: Column): Column = withExpr { Log1p(e.expr) } /** * Computes the natural logarithm of the given column plus one. @@ -1360,7 +1346,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def log2(expr: Column): Column = Log2(expr.expr) + def log2(expr: Column): Column = withExpr { Log2(expr.expr) } /** * Computes the logarithm of the given value in base 2. @@ -1376,7 +1362,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Column, r: Column): Column = Pow(l.expr, r.expr) + def pow(l: Column, r: Column): Column = withExpr { Pow(l.expr, r.expr) } /** * Returns the value of the first argument raised to the power of the second argument. @@ -1408,7 +1394,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Column, r: Double): Column = pow(l, lit(r).expr) + def pow(l: Column, r: Double): Column = pow(l, lit(r)) /** * Returns the value of the first argument raised to the power of the second argument. @@ -1424,7 +1410,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def pow(l: Double, r: Column): Column = pow(lit(l).expr, r) + def pow(l: Double, r: Column): Column = pow(lit(l), r) /** * Returns the value of the first argument raised to the power of the second argument. @@ -1440,7 +1426,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr) + def pmod(dividend: Column, divisor: Column): Column = withExpr { + Pmod(dividend.expr, divisor.expr) + } /** * Returns the double value that is closest in value to the argument and @@ -1449,7 +1437,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def rint(e: Column): Column = Rint(e.expr) + def rint(e: Column): Column = withExpr { Rint(e.expr) } /** * Returns the double value that is closest in value to the argument and @@ -1466,7 +1454,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def round(e: Column): Column = round(e.expr, 0) + def round(e: Column): Column = round(e, 0) /** * Round the value of `e` to `scale` decimal places if `scale` >= 0 @@ -1475,7 +1463,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def round(e: Column, scale: Int): Column = Round(e.expr, Literal(scale)) + def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) } /** * Shift the the given value numBits left. If the given value is a long value, this function @@ -1484,7 +1472,7 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def shiftLeft(e: Column, numBits: Int): Column = ShiftLeft(e.expr, lit(numBits).expr) + def shiftLeft(e: Column, numBits: Int): Column = withExpr { ShiftLeft(e.expr, lit(numBits).expr) } /** * Shift the the given value numBits right. If the given value is a long value, it will return @@ -1493,7 +1481,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def shiftRight(e: Column, numBits: Int): Column = ShiftRight(e.expr, lit(numBits).expr) + def shiftRight(e: Column, numBits: Int): Column = withExpr { + ShiftRight(e.expr, lit(numBits).expr) + } /** * Unsigned shift the the given value numBits right. If the given value is a long value, @@ -1502,8 +1492,9 @@ object functions { * @group math_funcs * @since 1.5.0 */ - def shiftRightUnsigned(e: Column, numBits: Int): Column = + def shiftRightUnsigned(e: Column, numBits: Int): Column = withExpr { ShiftRightUnsigned(e.expr, lit(numBits).expr) + } /** * Computes the signum of the given value. @@ -1511,7 +1502,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def signum(e: Column): Column = Signum(e.expr) + def signum(e: Column): Column = withExpr { Signum(e.expr) } /** * Computes the signum of the given column. @@ -1527,7 +1518,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def sin(e: Column): Column = Sin(e.expr) + def sin(e: Column): Column = withExpr { Sin(e.expr) } /** * Computes the sine of the given column. @@ -1543,7 +1534,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def sinh(e: Column): Column = Sinh(e.expr) + def sinh(e: Column): Column = withExpr { Sinh(e.expr) } /** * Computes the hyperbolic sine of the given column. @@ -1559,7 +1550,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def tan(e: Column): Column = Tan(e.expr) + def tan(e: Column): Column = withExpr { Tan(e.expr) } /** * Computes the tangent of the given column. @@ -1575,7 +1566,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def tanh(e: Column): Column = Tanh(e.expr) + def tanh(e: Column): Column = withExpr { Tanh(e.expr) } /** * Computes the hyperbolic tangent of the given column. @@ -1591,7 +1582,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def toDegrees(e: Column): Column = ToDegrees(e.expr) + def toDegrees(e: Column): Column = withExpr { ToDegrees(e.expr) } /** * Converts an angle measured in radians to an approximately equivalent angle measured in degrees. @@ -1607,7 +1598,7 @@ object functions { * @group math_funcs * @since 1.4.0 */ - def toRadians(e: Column): Column = ToRadians(e.expr) + def toRadians(e: Column): Column = withExpr { ToRadians(e.expr) } /** * Converts an angle measured in degrees to an approximately equivalent angle measured in radians. @@ -1628,7 +1619,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def md5(e: Column): Column = Md5(e.expr) + def md5(e: Column): Column = withExpr { Md5(e.expr) } /** * Calculates the SHA-1 digest of a binary column and returns the value @@ -1637,7 +1628,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def sha1(e: Column): Column = Sha1(e.expr) + def sha1(e: Column): Column = withExpr { Sha1(e.expr) } /** * Calculates the SHA-2 family of hash functions of a binary column and @@ -1652,7 +1643,7 @@ object functions { def sha2(e: Column, numBits: Int): Column = { require(Seq(0, 224, 256, 384, 512).contains(numBits), s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)") - Sha2(e.expr, lit(numBits).expr) + withExpr { Sha2(e.expr, lit(numBits).expr) } } /** @@ -1662,7 +1653,7 @@ object functions { * @group misc_funcs * @since 1.5.0 */ - def crc32(e: Column): Column = Crc32(e.expr) + def crc32(e: Column): Column = withExpr { Crc32(e.expr) } ////////////////////////////////////////////////////////////////////////////////////////////// // String functions @@ -1675,7 +1666,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def ascii(e: Column): Column = Ascii(e.expr) + def ascii(e: Column): Column = withExpr { Ascii(e.expr) } /** * Computes the BASE64 encoding of a binary column and returns it as a string column. @@ -1684,7 +1675,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def base64(e: Column): Column = Base64(e.expr) + def base64(e: Column): Column = withExpr { Base64(e.expr) } /** * Concatenates multiple input string columns together into a single string column. @@ -1693,7 +1684,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def concat(exprs: Column*): Column = Concat(exprs.map(_.expr)) + def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) } /** * Concatenates multiple input string columns together into a single string column, @@ -1703,7 +1694,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def concat_ws(sep: String, exprs: Column*): Column = { + def concat_ws(sep: String, exprs: Column*): Column = withExpr { ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr)) } @@ -1715,7 +1706,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr) + def decode(value: Column, charset: String): Column = withExpr { + Decode(value.expr, lit(charset).expr) + } /** * Computes the first argument into a binary from a string using the provided character set @@ -1725,7 +1718,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr) + def encode(value: Column, charset: String): Column = withExpr { + Encode(value.expr, lit(charset).expr) + } /** * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places, @@ -1737,7 +1732,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) + def format_number(x: Column, d: Int): Column = withExpr { + FormatNumber(x.expr, lit(d).expr) + } /** * Formats the arguments in printf-style and returns the result as a string column. @@ -1746,7 +1743,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def format_string(format: String, arguments: Column*): Column = { + def format_string(format: String, arguments: Column*): Column = withExpr { FormatString((lit(format) +: arguments).map(_.expr): _*) } @@ -1759,7 +1756,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def initcap(e: Column): Column = InitCap(e.expr) + def initcap(e: Column): Column = withExpr { InitCap(e.expr) } /** * Locate the position of the first occurrence of substr column in the given string. @@ -1771,7 +1768,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr) + def instr(str: Column, substring: String): Column = withExpr { + StringInstr(str.expr, lit(substring).expr) + } /** * Computes the length of a given string or binary column. @@ -1779,7 +1778,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def length(e: Column): Column = Length(e.expr) + def length(e: Column): Column = withExpr { Length(e.expr) } /** * Converts a string column to lower case. @@ -1787,14 +1786,14 @@ object functions { * @group string_funcs * @since 1.3.0 */ - def lower(e: Column): Column = Lower(e.expr) + def lower(e: Column): Column = withExpr { Lower(e.expr) } /** * Computes the Levenshtein distance of the two given string columns. * @group string_funcs * @since 1.5.0 */ - def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr) + def levenshtein(l: Column, r: Column): Column = withExpr { Levenshtein(l.expr, r.expr) } /** * Locate the position of the first occurrence of substr. @@ -1804,7 +1803,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: Column): Column = { + def locate(substr: String, str: Column): Column = withExpr { new StringLocate(lit(substr).expr, str.expr) } @@ -1817,7 +1816,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: Column, pos: Int): Column = { + def locate(substr: String, str: Column, pos: Int): Column = withExpr { StringLocate(lit(substr).expr, str.expr, lit(pos).expr) } @@ -1827,7 +1826,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def lpad(str: Column, len: Int, pad: String): Column = { + def lpad(str: Column, len: Int, pad: String): Column = withExpr { StringLPad(str.expr, lit(len).expr, lit(pad).expr) } @@ -1837,7 +1836,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def ltrim(e: Column): Column = StringTrimLeft(e.expr) + def ltrim(e: Column): Column = withExpr {StringTrimLeft(e.expr) } /** * Extract a specific(idx) group identified by a java regex, from the specified string column. @@ -1845,7 +1844,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = { + def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = withExpr { RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr) } @@ -1855,7 +1854,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def regexp_replace(e: Column, pattern: String, replacement: String): Column = { + def regexp_replace(e: Column, pattern: String, replacement: String): Column = withExpr { RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr) } @@ -1866,7 +1865,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def unbase64(e: Column): Column = UnBase64(e.expr) + def unbase64(e: Column): Column = withExpr { UnBase64(e.expr) } /** * Right-padded with pad to a length of len. @@ -1874,7 +1873,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def rpad(str: Column, len: Int, pad: String): Column = { + def rpad(str: Column, len: Int, pad: String): Column = withExpr { StringRPad(str.expr, lit(len).expr, lit(pad).expr) } @@ -1884,7 +1883,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def repeat(str: Column, n: Int): Column = { + def repeat(str: Column, n: Int): Column = withExpr { StringRepeat(str.expr, lit(n).expr) } @@ -1894,9 +1893,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def reverse(str: Column): Column = { - StringReverse(str.expr) - } + def reverse(str: Column): Column = withExpr { StringReverse(str.expr) } /** * Trim the spaces from right end for the specified string value. @@ -1904,7 +1901,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def rtrim(e: Column): Column = StringTrimRight(e.expr) + def rtrim(e: Column): Column = withExpr { StringTrimRight(e.expr) } /** * * Return the soundex code for the specified expression. @@ -1912,7 +1909,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def soundex(e: Column): Column = SoundEx(e.expr) + def soundex(e: Column): Column = withExpr { SoundEx(e.expr) } /** * Splits str around pattern (pattern is a regular expression). @@ -1921,7 +1918,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def split(str: Column, pattern: String): Column = { + def split(str: Column, pattern: String): Column = withExpr { StringSplit(str.expr, lit(pattern).expr) } @@ -1933,8 +1930,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def substring(str: Column, pos: Int, len: Int): Column = + def substring(str: Column, pos: Int, len: Int): Column = withExpr { Substring(str.expr, lit(pos).expr, lit(len).expr) + } /** * Returns the substring from string str before count occurrences of the delimiter delim. @@ -1944,8 +1942,9 @@ object functions { * * @group string_funcs */ - def substring_index(str: Column, delim: String, count: Int): Column = + def substring_index(str: Column, delim: String, count: Int): Column = withExpr { SubstringIndex(str.expr, lit(delim).expr, lit(count).expr) + } /** * Translate any character in the src by a character in replaceString. @@ -1956,8 +1955,9 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def translate(src: Column, matchingString: String, replaceString: String): Column = + def translate(src: Column, matchingString: String, replaceString: String): Column = withExpr { StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr) + } /** * Trim the spaces from both ends for the specified string column. @@ -1965,7 +1965,7 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def trim(e: Column): Column = StringTrim(e.expr) + def trim(e: Column): Column = withExpr { StringTrim(e.expr) } /** * Converts a string column to upper case. @@ -1973,7 +1973,7 @@ object functions { * @group string_funcs * @since 1.3.0 */ - def upper(e: Column): Column = Upper(e.expr) + def upper(e: Column): Column = withExpr { Upper(e.expr) } ////////////////////////////////////////////////////////////////////////////////////////////// // DateTime functions @@ -1985,8 +1985,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def add_months(startDate: Column, numMonths: Int): Column = + def add_months(startDate: Column, numMonths: Int): Column = withExpr { AddMonths(startDate.expr, Literal(numMonths)) + } /** * Returns the current date as a date column. @@ -1994,7 +1995,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def current_date(): Column = CurrentDate() + def current_date(): Column = withExpr { CurrentDate() } /** * Returns the current timestamp as a timestamp column. @@ -2002,7 +2003,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def current_timestamp(): Column = CurrentTimestamp() + def current_timestamp(): Column = withExpr { CurrentTimestamp() } /** * Converts a date/timestamp/string to a value of string in the format specified by the date @@ -2017,71 +2018,72 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def date_format(dateExpr: Column, format: String): Column = + def date_format(dateExpr: Column, format: String): Column = withExpr { DateFormatClass(dateExpr.expr, Literal(format)) + } /** * Returns the date that is `days` days after `start` * @group datetime_funcs * @since 1.5.0 */ - def date_add(start: Column, days: Int): Column = DateAdd(start.expr, Literal(days)) + def date_add(start: Column, days: Int): Column = withExpr { DateAdd(start.expr, Literal(days)) } /** * Returns the date that is `days` days before `start` * @group datetime_funcs * @since 1.5.0 */ - def date_sub(start: Column, days: Int): Column = DateSub(start.expr, Literal(days)) + def date_sub(start: Column, days: Int): Column = withExpr { DateSub(start.expr, Literal(days)) } /** * Returns the number of days from `start` to `end`. * @group datetime_funcs * @since 1.5.0 */ - def datediff(end: Column, start: Column): Column = DateDiff(end.expr, start.expr) + def datediff(end: Column, start: Column): Column = withExpr { DateDiff(end.expr, start.expr) } /** * Extracts the year as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def year(e: Column): Column = Year(e.expr) + def year(e: Column): Column = withExpr { Year(e.expr) } /** * Extracts the quarter as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def quarter(e: Column): Column = Quarter(e.expr) + def quarter(e: Column): Column = withExpr { Quarter(e.expr) } /** * Extracts the month as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def month(e: Column): Column = Month(e.expr) + def month(e: Column): Column = withExpr { Month(e.expr) } /** * Extracts the day of the month as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def dayofmonth(e: Column): Column = DayOfMonth(e.expr) + def dayofmonth(e: Column): Column = withExpr { DayOfMonth(e.expr) } /** * Extracts the day of the year as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def dayofyear(e: Column): Column = DayOfYear(e.expr) + def dayofyear(e: Column): Column = withExpr { DayOfYear(e.expr) } /** * Extracts the hours as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def hour(e: Column): Column = Hour(e.expr) + def hour(e: Column): Column = withExpr { Hour(e.expr) } /** * Given a date column, returns the last day of the month which the given date belongs to. @@ -2091,21 +2093,23 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def last_day(e: Column): Column = LastDay(e.expr) + def last_day(e: Column): Column = withExpr { LastDay(e.expr) } /** * Extracts the minutes as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def minute(e: Column): Column = Minute(e.expr) + def minute(e: Column): Column = withExpr { Minute(e.expr) } /* * Returns number of months between dates `date1` and `date2`. * @group datetime_funcs * @since 1.5.0 */ - def months_between(date1: Column, date2: Column): Column = MonthsBetween(date1.expr, date2.expr) + def months_between(date1: Column, date2: Column): Column = withExpr { + MonthsBetween(date1.expr, date2.expr) + } /** * Given a date column, returns the first date which is later than the value of the date column @@ -2120,21 +2124,23 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def next_day(date: Column, dayOfWeek: String): Column = NextDay(date.expr, lit(dayOfWeek).expr) + def next_day(date: Column, dayOfWeek: String): Column = withExpr { + NextDay(date.expr, lit(dayOfWeek).expr) + } /** * Extracts the seconds as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def second(e: Column): Column = Second(e.expr) + def second(e: Column): Column = withExpr { Second(e.expr) } /** * Extracts the week number as an integer from a given date/timestamp/string. * @group datetime_funcs * @since 1.5.0 */ - def weekofyear(e: Column): Column = WeekOfYear(e.expr) + def weekofyear(e: Column): Column = withExpr { WeekOfYear(e.expr) } /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string @@ -2143,7 +2149,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def from_unixtime(ut: Column): Column = FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss")) + def from_unixtime(ut: Column): Column = withExpr { + FromUnixTime(ut.expr, Literal("yyyy-MM-dd HH:mm:ss")) + } /** * Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string @@ -2152,14 +2160,18 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def from_unixtime(ut: Column, f: String): Column = FromUnixTime(ut.expr, Literal(f)) + def from_unixtime(ut: Column, f: String): Column = withExpr { + FromUnixTime(ut.expr, Literal(f)) + } /** * Gets current Unix timestamp in seconds. * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(): Column = UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")) + def unix_timestamp(): Column = withExpr { + UnixTimestamp(CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")) + } /** * Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), @@ -2167,7 +2179,9 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column): Column = UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + def unix_timestamp(s: Column): Column = withExpr { + UnixTimestamp(s.expr, Literal("yyyy-MM-dd HH:mm:ss")) + } /** * Convert time string with given pattern @@ -2176,7 +2190,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def unix_timestamp(s: Column, p: String): Column = UnixTimestamp(s.expr, Literal(p)) + def unix_timestamp(s: Column, p: String): Column = withExpr {UnixTimestamp(s.expr, Literal(p)) } /** * Converts the column into DateType. @@ -2184,7 +2198,7 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def to_date(e: Column): Column = ToDate(e.expr) + def to_date(e: Column): Column = withExpr { ToDate(e.expr) } /** * Returns date truncated to the unit specified by the format. @@ -2195,22 +2209,27 @@ object functions { * @group datetime_funcs * @since 1.5.0 */ - def trunc(date: Column, format: String): Column = TruncDate(date.expr, Literal(format)) + def trunc(date: Column, format: String): Column = withExpr { + TruncDate(date.expr, Literal(format)) + } /** * Assumes given timestamp is UTC and converts to given timezone. * @group datetime_funcs * @since 1.5.0 */ - def from_utc_timestamp(ts: Column, tz: String): Column = - FromUTCTimestamp(ts.expr, Literal(tz).expr) + def from_utc_timestamp(ts: Column, tz: String): Column = withExpr { + FromUTCTimestamp(ts.expr, Literal(tz)) + } /** * Assumes given timestamp is in given timezone and converts to UTC. * @group datetime_funcs * @since 1.5.0 */ - def to_utc_timestamp(ts: Column, tz: String): Column = ToUTCTimestamp(ts.expr, Literal(tz).expr) + def to_utc_timestamp(ts: Column, tz: String): Column = withExpr { + ToUTCTimestamp(ts.expr, Literal(tz)) + } ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions @@ -2221,8 +2240,9 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def array_contains(column: Column, value: Any): Column = + def array_contains(column: Column, value: Any): Column = withExpr { ArrayContains(column.expr, Literal(value)) + } /** * Creates a new row for each element in the given array or map column. @@ -2230,7 +2250,7 @@ object functions { * @group collection_funcs * @since 1.3.0 */ - def explode(e: Column): Column = Explode(e.expr) + def explode(e: Column): Column = withExpr { Explode(e.expr) } /** * Returns length of array or map. @@ -2238,7 +2258,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def size(e: Column): Column = Size(e.expr) + def size(e: Column): Column = withExpr { Size(e.expr) } /** * Sorts the input array for the given column in ascending order, @@ -2256,7 +2276,7 @@ object functions { * @group collection_funcs * @since 1.5.0 */ - def sort_array(e: Column, asc: Boolean): Column = SortArray(e.expr, lit(asc).expr) + def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// @@ -2296,11 +2316,10 @@ object functions { * @deprecated As of 1.5.0, since it's redundant with udf() */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { + def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = withExpr { ScalaUDF(f, returnType, Seq($argsInUDF)) }""") } - } */ /** * Defines a user-defined function of 0 arguments as user-defined function (UDF). @@ -2435,147 +2454,146 @@ object functions { } ////////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Call a Scala function of 0 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 0 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function0[_], returnType: DataType): Column = { + def callUDF(f: Function0[_], returnType: DataType): Column = withExpr { ScalaUDF(f, returnType, Seq()) } /** - * Call a Scala function of 1 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 1 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { + def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr)) } /** - * Call a Scala function of 2 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 2 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { + def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } /** - * Call a Scala function of 3 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 3 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { + def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } /** - * Call a Scala function of 4 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 4 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } /** - * Call a Scala function of 5 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 5 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } /** - * Call a Scala function of 6 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 6 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } /** - * Call a Scala function of 7 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 7 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } /** - * Call a Scala function of 8 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 8 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } /** - * Call a Scala function of 9 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 9 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } /** - * Call a Scala function of 10 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ + * Call a Scala function of 10 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + * + * @group udf_funcs + * @since 1.3.0 + * @deprecated As of 1.5.0, since it's redundant with udf() + */ @deprecated("Use udf", "1.5.0") - def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = withExpr { ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } @@ -2597,7 +2615,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def callUDF(udfName: String, cols: Column*): Column = { + def callUDF(udfName: String, cols: Column*): Column = withExpr { UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } @@ -2618,7 +2636,7 @@ object functions { * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF */ @deprecated("Use callUDF", "1.5.0") - def callUdf(udfName: String, cols: Column*): Column = { + def callUdf(udfName: String, cols: Column*): Column = withExpr { // Note: we avoid using closures here because on file systems that are case-insensitive, the // compiled class file for the closure here will conflict with the one in callUDF (upper case). val exprs = new Array[Expression](cols.size) From 244010624200eddea6dfd1b2c89f40be45212e96 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 5 Nov 2015 16:34:10 -0800 Subject: [PATCH 02/88] [SPARK-11542] [SPARKR] fix glm with long fomular Because deparse() will break the long string into multiple lines, the deserialization will fail Author: Davies Liu Closes #9510 from davies/fix_glm. --- R/pkg/R/mllib.R | 3 ++- R/pkg/inst/tests/test_mllib.R | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 60bfadb8e7503..b0d73dd93a79d 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -48,8 +48,9 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, standardize = TRUE, solver = "auto") { family <- match.arg(family) + formula <- paste(deparse(formula), collapse="") model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "fitRModelFormula", deparse(formula), data@sdf, family, lambda, + "fitRModelFormula", formula, data@sdf, family, lambda, alpha, standardize, solver) return(new("PipelineModel", model = model)) }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 032cfef061fd3..4761e285a2479 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -33,6 +33,18 @@ test_that("glm and predict", { expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") }) +test_that("glm should work with long formula", { + training <- createDataFrame(sqlContext, iris) + training$LongLongLongLongLongName <- training$Sepal_Width + training$VeryLongLongLongLonLongName <- training$Sepal_Length + training$AnotherLongLongLongLongName <- training$Species + model <- glm(LongLongLongLongLongName ~ VeryLongLongLongLonLongName + AnotherLongLongLongLongName, + data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + test_that("predictions match with native glm", { training <- createDataFrame(sqlContext, iris) model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) From 07414afac9a100ede1dee5f3d45a657802c8bd2a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 5 Nov 2015 17:02:22 -0800 Subject: [PATCH 03/88] [SPARK-11537] [SQL] fix negative hours/minutes/seconds Currently, if the Timestamp is before epoch (1970/01/01), the hours, minutes and seconds will be negative (also rounding up). Author: Davies Liu Closes #9502 from davies/neg_hour. --- .../sql/catalyst/util/DateTimeUtils.scala | 23 ++++++++++++------- .../catalyst/util/DateTimeUtilsSuite.scala | 13 +++++++++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 781ed1688a327..f5fff90e5a542 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -392,29 +392,36 @@ object DateTimeUtils { Some((c.getTimeInMillis / MILLIS_PER_DAY).toInt) } + /** + * Returns the microseconds since year zero (-17999) from microseconds since epoch. + */ + def absoluteMicroSecond(microsec: SQLTimestamp): SQLTimestamp = { + microsec + toYearZero * MICROS_PER_DAY + } + /** * Returns the hour value of a given timestamp value. The timestamp is expressed in microseconds. */ - def getHours(timestamp: SQLTimestamp): Int = { - val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000) - ((localTs / 1000 / 3600) % 24).toInt + def getHours(microsec: SQLTimestamp): Int = { + val localTs = absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L + ((localTs / MICROS_PER_SECOND / 3600) % 24).toInt } /** * Returns the minute value of a given timestamp value. The timestamp is expressed in * microseconds. */ - def getMinutes(timestamp: SQLTimestamp): Int = { - val localTs = (timestamp / 1000) + defaultTimeZone.getOffset(timestamp / 1000) - ((localTs / 1000 / 60) % 60).toInt + def getMinutes(microsec: SQLTimestamp): Int = { + val localTs = absoluteMicroSecond(microsec) + defaultTimeZone.getOffset(microsec / 1000) * 1000L + ((localTs / MICROS_PER_SECOND / 60) % 60).toInt } /** * Returns the second value of a given timestamp value. The timestamp is expressed in * microseconds. */ - def getSeconds(timestamp: SQLTimestamp): Int = { - ((timestamp / 1000 / 1000) % 60).toInt + def getSeconds(microsec: SQLTimestamp): Int = { + ((absoluteMicroSecond(microsec) / MICROS_PER_SECOND) % 60).toInt } private[this] def isLeapYear(year: Int): Boolean = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 46335941b62d6..64d15e6b910c1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -358,6 +358,19 @@ class DateTimeUtilsSuite extends SparkFunSuite { assert(getSeconds(c.getTimeInMillis * 1000) === 9) } + test("hours / miniute / seconds") { + Seq(Timestamp.valueOf("2015-06-11 10:12:35.789"), + Timestamp.valueOf("2015-06-11 20:13:40.789"), + Timestamp.valueOf("1900-06-11 12:14:50.789"), + Timestamp.valueOf("1700-02-28 12:14:50.123456")).foreach { t => + val us = fromJavaTimestamp(t) + assert(toJavaTimestamp(us) === t) + assert(getHours(us) === t.getHours) + assert(getMinutes(us) === t.getMinutes) + assert(getSeconds(us) === t.getSeconds) + } + } + test("get day in year") { val c = Calendar.getInstance() c.set(2015, 2, 18, 0, 0, 0) From 6091e91fca58078a0f1d9c35d68c0ae7205a534c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 5 Nov 2015 17:10:35 -0800 Subject: [PATCH 04/88] Revert "[SPARK-11469][SQL] Allow users to define nondeterministic udfs." This reverts commit 9cf56c96b7d02a14175d40b336da14c2e1c88339. --- project/MimaExcludes.scala | 47 ----- .../sql/catalyst/expressions/ScalaUDF.scala | 7 +- .../apache/spark/sql/UDFRegistration.scala | 164 ++++++++---------- .../spark/sql/UserDefinedFunction.scala | 13 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 105 ----------- .../datasources/parquet/ParquetIOSuite.scala | 4 +- 6 files changed, 78 insertions(+), 262 deletions(-) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 90dc947d4e588..40f5c9fec8bb8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -114,53 +114,6 @@ object MimaExcludes { "org.apache.spark.rdd.MapPartitionsWithPreparationRDD"), ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$") - ) ++ Seq( - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$2"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$3"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$4"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$5"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$6"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$7"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$8"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$9"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$10"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$11"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$12"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$13"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$14"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$15"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$16"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$17"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$18"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$19"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$20"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$21"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$22"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$23"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24") ) ++ Seq( // SPARK-11485 ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index a04af7f1dd877..11c7950c0613b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -30,18 +30,13 @@ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], - inputTypes: Seq[DataType] = Nil, - isDeterministic: Boolean = true) + inputTypes: Seq[DataType] = Nil) extends Expression with ImplicitCastInputTypes with CodegenFallback { override def nullable: Boolean = true override def toString: String = s"UDF(${children.mkString(",")})" - override def foldable: Boolean = deterministic && children.forall(_.foldable) - - override def deterministic: Boolean = isDeterministic && children.forall(_.deterministic) - // scalastyle:off /** This method has been generated by this script diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index f5b95e13e47bc..fc4d0938c533a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -58,10 +58,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined aggregate function (UDAF). * * @param name the name of the UDAF. - * @param udaf the UDAF that needs to be registered. + * @param udaf the UDAF needs to be registered. * @return the registered UDAF. - * - * @since 1.5.0 */ def register( name: String, @@ -71,22 +69,6 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { udaf } - /** - * Register a user-defined function (UDF). - * - * @param name the name of the UDF. - * @param udf the UDF that needs to be registered. - * @return the registered UDF. - * - * @since 1.6.0 - */ - def register( - name: String, - udf: UserDefinedFunction): UserDefinedFunction = { - functionRegistry.registerFunction(name, udf.builder) - udf - } - // scalastyle:off /* register 0-22 were generated by this script @@ -104,9 +86,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try($inputTypes).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) }""") } @@ -136,9 +118,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -149,9 +131,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -162,9 +144,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -175,9 +157,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -188,9 +170,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -201,9 +183,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -214,9 +196,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -227,9 +209,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -240,9 +222,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -253,9 +235,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -266,9 +248,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -279,9 +261,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -292,9 +274,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -305,9 +287,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -318,9 +300,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -331,9 +313,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -344,9 +326,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -357,9 +339,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -370,9 +352,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -383,9 +365,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -396,9 +378,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -409,9 +391,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } /** @@ -422,9 +404,9 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) - val udf = UserDefinedFunction(func, dataType, inputTypes) - functionRegistry.registerFunction(name, udf.builder) - udf + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + functionRegistry.registerFunction(name, builder) + UserDefinedFunction(func, dataType) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index 1319391db5375..0f8cd280b5acb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -44,20 +44,11 @@ import org.apache.spark.sql.types.DataType case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, - inputTypes: Seq[DataType] = Nil, - deterministic: Boolean = true) { + inputTypes: Seq[DataType] = Nil) { def apply(exprs: Column*): Column = { - Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes, deterministic)) + Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes)) } - - protected[sql] def builder: Seq[Expression] => ScalaUDF = { - (exprs: Seq[Expression]) => - ScalaUDF(f, dataType, exprs, inputTypes, deterministic) - } - - def nondeterministic: UserDefinedFunction = - UserDefinedFunction(f, dataType, inputTypes, deterministic = false) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 6e510f0b8aff4..e0435a0dba6ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.ScalaUDF -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ @@ -193,107 +191,4 @@ class UDFSuite extends QueryTest with SharedSQLContext { // pass a decimal to intExpected. assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } - - private def checkNumUDFs(df: DataFrame, expectedNumUDFs: Int): Unit = { - val udfs = df.queryExecution.optimizedPlan.collect { - case p: logical.Project => p.projectList.flatMap { - case e => e.collect { - case udf: ScalaUDF => udf - } - } - }.flatten - assert(udfs.length === expectedNumUDFs) - } - - test("foldable udf") { - import org.apache.spark.sql.functions._ - - val myUDF = udf((x: Int) => x + 1) - - { - val df = sql("SELECT 1 as a") - .select(col("a"), myUDF(col("a")).as("b")) - .select(col("a"), col("b"), myUDF(col("b")).as("c")) - checkNumUDFs(df, 0) - checkAnswer(df, Row(1, 2, 3)) - } - } - - test("nondeterministic udf: using UDFRegistration") { - import org.apache.spark.sql.functions._ - - val myUDF = sqlContext.udf.register("plusOne1", (x: Int) => x + 1) - sqlContext.udf.register("plusOne2", myUDF.nondeterministic) - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), myUDF(col("a")).as("b")) - .select(col("a"), col("b"), myUDF(col("b")).as("c")) - checkNumUDFs(df, 3) - checkAnswer(df, Row(1, 2, 3)) - } - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), callUDF("plusOne1", col("a")).as("b")) - .select(col("a"), col("b"), callUDF("plusOne1", col("b")).as("c")) - checkNumUDFs(df, 3) - checkAnswer(df, Row(1, 2, 3)) - } - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) - .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) - checkNumUDFs(df, 2) - checkAnswer(df, Row(1, 2, 3)) - } - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), callUDF("plusOne2", col("a")).as("b")) - .select(col("a"), col("b"), callUDF("plusOne2", col("b")).as("c")) - checkNumUDFs(df, 2) - checkAnswer(df, Row(1, 2, 3)) - } - } - - test("nondeterministic udf: using udf function") { - import org.apache.spark.sql.functions._ - - val myUDF = udf((x: Int) => x + 1) - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), myUDF(col("a")).as("b")) - .select(col("a"), col("b"), myUDF(col("b")).as("c")) - checkNumUDFs(df, 3) - checkAnswer(df, Row(1, 2, 3)) - } - - { - val df = sqlContext.range(1, 2).select(col("id").as("a")) - .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) - .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) - checkNumUDFs(df, 2) - checkAnswer(df, Row(1, 2, 3)) - } - - { - // nondeterministicUDF will not be foldable. - val df = sql("SELECT 1 as a") - .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) - .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) - checkNumUDFs(df, 2) - checkAnswer(df, Row(1, 2, 3)) - } - } - - test("override a registered udf") { - sqlContext.udf.register("intExpected", (x: Int) => x) - assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) - - sqlContext.udf.register("intExpected", (x: Int) => x + 1) - assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 2) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index f14b2886a9ecb..72744799897be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -381,7 +381,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath) + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration) @@ -405,7 +405,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath) + sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration) From 8fa8c8375d7015a0332aa9ee613d7c6b6d62bae7 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 5 Nov 2015 17:59:01 -0800 Subject: [PATCH 05/88] [SPARK-11514][ML] Pass random seed to spark.ml DecisionTree* cc jkbradley Author: Yu ISHIKAWA Closes #9486 from yu-iskw/SPARK-11514. --- .../ml/classification/DecisionTreeClassifier.scala | 4 +++- .../spark/ml/regression/DecisionTreeRegressor.scala | 4 +++- .../scala/org/apache/spark/ml/tree/treeParams.scala | 11 ++++++----- .../classification/DecisionTreeClassifierSuite.scala | 1 + .../ml/regression/DecisionTreeRegressorSuite.scala | 1 + 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index b0157f7ce24ec..c478aea44ace8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -62,6 +62,8 @@ final class DecisionTreeClassifier(override val uid: String) override def setImpurity(value: String): this.type = super.setImpurity(value) + override def setSeed(value: Long): this.type = super.setSeed(value) + override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -75,7 +77,7 @@ final class DecisionTreeClassifier(override val uid: String) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures, numClasses) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 0L, parentUID = Some(uid)) + seed = $(seed), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeClassificationModel] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 04420fc6e8251..477030d9ea3ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -71,13 +71,15 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) + override def setSeed(value: Long): this.type = super.setSeed(value) + override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 0L, parentUID = Some(uid)) + seed = $(seed), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeRegressionModel] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 281ba6eeffa92..1da97db9277d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -29,7 +29,8 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointInterval { +private[ml] trait DecisionTreeParams extends PredictorParams + with HasCheckpointInterval with HasSeed { /** * Maximum depth of the tree (>= 0). @@ -123,6 +124,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI /** @group getParam */ final def getMinInfoGain: Double = $(minInfoGain) + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + /** @group expertSetParam */ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) @@ -257,7 +261,7 @@ private[ml] object TreeRegressorParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { +private[ml] trait TreeEnsembleParams extends DecisionTreeParams { /** * Fraction of the training data used for learning each decision tree, in range (0, 1]. @@ -276,9 +280,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** @group getParam */ final def getSubsamplingRate: Double = $(subsamplingRate) - /** @group setParam */ - def setSeed(value: Long): this.type = set(seed, value) - /** * Create a Strategy instance to use with the old API. * NOTE: The caller should set impurity and seed. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 815f6fd997584..92b8f84144ab0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -72,6 +72,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setImpurity("gini") .setMaxDepth(2) .setMaxBins(100) + .setSeed(1) val categoricalFeatures = Map(0 -> 3, 1-> 3) val numClasses = 2 compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 868fb8eecb8bb..e0d5afa7a7e97 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -49,6 +49,7 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setImpurity("variance") .setMaxDepth(2) .setMaxBins(100) + .setSeed(1) val categoricalFeatures = Map(0 -> 3, 1-> 3) compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } From 468ad0ae874d5cf55712ee976faf77f19c937ccb Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 5 Nov 2015 18:03:12 -0800 Subject: [PATCH 06/88] [SPARK-11457][STREAMING][YARN] Fix incorrect AM proxy filter conf recovery from checkpoint Currently Yarn AM proxy filter configuration is recovered from checkpoint file when Spark Streaming application is restarted, which will lead to some unwanted behaviors: 1. Wrong RM address if RM is redeployed from failure. 2. Wrong proxyBase, since app id is updated, old app id for proxyBase is wrong. So instead of recovering from checkpoint file, these configurations should be reloaded each time when app started. This problem only exists in Yarn cluster mode, for Yarn client mode, these configurations will be updated with RPC message `AddWebUIFilter`. Please help to review tdas harishreedharan vanzin , thanks a lot. Author: jerryshao Closes #9412 from jerryshao/SPARK-11457. --- .../org/apache/spark/streaming/Checkpoint.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index b7de6dde61c63..0cd55d9aec2cd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -55,7 +55,8 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) "spark.driver.port", "spark.master", "spark.yarn.keytab", - "spark.yarn.principal") + "spark.yarn.principal", + "spark.ui.filters") val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") @@ -66,6 +67,16 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) newSparkConf.set(prop, value) } } + + // Add Yarn proxy filter specific configurations to the recovered SparkConf + val filter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + val filterPrefix = s"spark.$filter.param." + newReloadConf.getAll.foreach { case (k, v) => + if (k.startsWith(filterPrefix) && k.length > filterPrefix.length) { + newSparkConf.set(k, v) + } + } + newSparkConf } From 5e31db70bb783656ba042863fcd3c223e17a8f81 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 5 Nov 2015 18:05:58 -0800 Subject: [PATCH 07/88] [SPARK-11538][BUILD] Force guava 14 in sbt build. sbt's version resolution code always picks the most recent version, and we don't want that for guava. Author: Marcelo Vanzin Closes #9508 from vanzin/SPARK-11538. --- project/SparkBuild.scala | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 75c36930decef..b75ed13a78c68 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -207,7 +207,8 @@ object SparkBuild extends PomBuild { // Note ordering of these settings matter. /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) - .foreach(enable(sharedSettings ++ ExcludedDependencies.settings ++ Revolver.settings)) + .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ + ExcludedDependencies.settings ++ Revolver.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -291,6 +292,14 @@ object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } +/** + * Overrides to work around sbt's dependency resolution being different from Maven's. + */ +object DependencyOverrides { + lazy val settings = Seq( + dependencyOverrides += "com.google.guava" % "guava" % "14.0.1") +} + /** This excludes library dependencies in sbt, which are specified in maven but are not needed by sbt build. From 3cc2c053b5d68c747a30bd58cf388b87b1922f13 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 5 Nov 2015 18:12:54 -0800 Subject: [PATCH 08/88] [SPARK-11540][SQL] API audit for QueryExecutionListener. Author: Reynold Xin Closes #9509 from rxin/SPARK-11540. --- .../spark/sql/execution/QueryExecution.scala | 30 +++--- .../sql/util/QueryExecutionListener.scala | 101 ++++++++++-------- 2 files changed, 72 insertions(+), 59 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index fc9174549e642..c2142d03f422b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import com.google.common.annotations.VisibleForTesting + import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow @@ -25,31 +27,33 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** * The primary workflow for executing relational queries using Spark. Designed to allow easy * access to the intermediate phases of query execution for developers. + * + * While this is not a public class, we should avoid changing the function names for the sake of + * changing them, because a lot of developers use the feature for debugging. */ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { - val analyzer = sqlContext.analyzer - val optimizer = sqlContext.optimizer - val planner = sqlContext.planner - val cacheManager = sqlContext.cacheManager - val prepareForExecution = sqlContext.prepareForExecution - def assertAnalyzed(): Unit = analyzer.checkAnalysis(analyzed) + @VisibleForTesting + def assertAnalyzed(): Unit = sqlContext.analyzer.checkAnalysis(analyzed) + + lazy val analyzed: LogicalPlan = sqlContext.analyzer.execute(logical) - lazy val analyzed: LogicalPlan = analyzer.execute(logical) lazy val withCachedData: LogicalPlan = { assertAnalyzed() - cacheManager.useCachedData(analyzed) + sqlContext.cacheManager.useCachedData(analyzed) } - lazy val optimizedPlan: LogicalPlan = optimizer.execute(withCachedData) + + lazy val optimizedPlan: LogicalPlan = sqlContext.optimizer.execute(withCachedData) // TODO: Don't just pick the first one... lazy val sparkPlan: SparkPlan = { SparkPlan.currentContext.set(sqlContext) - planner.plan(optimizedPlan).next() + sqlContext.planner.plan(optimizedPlan).next() } + // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = prepareForExecution.execute(sparkPlan) + lazy val executedPlan: SparkPlan = sqlContext.prepareForExecution.execute(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() @@ -57,11 +61,11 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } - def simpleString: String = + def simpleString: String = { s"""== Physical Plan == |${stringOrError(executedPlan)} """.stripMargin.trim - + } override def toString: String = { def output = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index 909a8abd225b8..ac432e2baa3c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -19,36 +19,38 @@ package org.apache.spark.sql.util import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable.ListBuffer +import scala.util.control.NonFatal -import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.Logging +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.sql.execution.QueryExecution /** + * :: Experimental :: * The interface of query execution listener that can be used to analyze execution metrics. * - * Note that implementations should guarantee thread-safety as they will be used in a non - * thread-safe way. + * Note that implementations should guarantee thread-safety as they can be invoked by + * multiple different threads. */ @Experimental trait QueryExecutionListener { /** * A callback function that will be called when a query executed successfully. - * Implementations should guarantee thread-safe. + * Note that this can be invoked by multiple different threads. * - * @param funcName the name of the action that triggered this query. + * @param funcName name of the action that triggered this query. * @param qe the QueryExecution object that carries detail information like logical plan, * physical plan, etc. - * @param duration the execution time for this query in nanoseconds. + * @param durationNs the execution time for this query in nanoseconds. */ @DeveloperApi - def onSuccess(funcName: String, qe: QueryExecution, duration: Long) + def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit /** * A callback function that will be called when a query execution failed. - * Implementations should guarantee thread-safe. + * Note that this can be invoked by multiple different threads. * * @param funcName the name of the action that triggered this query. * @param qe the QueryExecution object that carries detail information like logical plan, @@ -56,34 +58,20 @@ trait QueryExecutionListener { * @param exception the exception that failed this query. */ @DeveloperApi - def onFailure(funcName: String, qe: QueryExecution, exception: Exception) + def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit } -@Experimental -class ExecutionListenerManager extends Logging { - private[this] val listeners = ListBuffer.empty[QueryExecutionListener] - private[this] val lock = new ReentrantReadWriteLock() - - /** Acquires a read lock on the cache for the duration of `f`. */ - private def readLock[A](f: => A): A = { - val rl = lock.readLock() - rl.lock() - try f finally { - rl.unlock() - } - } - /** Acquires a write lock on the cache for the duration of `f`. */ - private def writeLock[A](f: => A): A = { - val wl = lock.writeLock() - wl.lock() - try f finally { - wl.unlock() - } - } +/** + * :: Experimental :: + * + * Manager for [[QueryExecutionListener]]. See [[org.apache.spark.sql.SQLContext.listenerManager]]. + */ +@Experimental +class ExecutionListenerManager private[sql] () extends Logging { /** - * Registers the specified QueryExecutionListener. + * Registers the specified [[QueryExecutionListener]]. */ @DeveloperApi def register(listener: QueryExecutionListener): Unit = writeLock { @@ -91,7 +79,7 @@ class ExecutionListenerManager extends Logging { } /** - * Unregisters the specified QueryExecutionListener. + * Unregisters the specified [[QueryExecutionListener]]. */ @DeveloperApi def unregister(listener: QueryExecutionListener): Unit = writeLock { @@ -99,38 +87,59 @@ class ExecutionListenerManager extends Logging { } /** - * clears out all registered QueryExecutionListeners. + * Removes all the registered [[QueryExecutionListener]]. */ @DeveloperApi def clear(): Unit = writeLock { listeners.clear() } - private[sql] def onSuccess( - funcName: String, - qe: QueryExecution, - duration: Long): Unit = readLock { - withErrorHandling { listener => - listener.onSuccess(funcName, qe, duration) + private[sql] def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + readLock { + withErrorHandling { listener => + listener.onSuccess(funcName, qe, duration) + } } } - private[sql] def onFailure( - funcName: String, - qe: QueryExecution, - exception: Exception): Unit = readLock { - withErrorHandling { listener => - listener.onFailure(funcName, qe, exception) + private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + readLock { + withErrorHandling { listener => + listener.onFailure(funcName, qe, exception) + } } } + private[this] val listeners = ListBuffer.empty[QueryExecutionListener] + + /** A lock to prevent updating the list of listeners while we are traversing through them. */ + private[this] val lock = new ReentrantReadWriteLock() + private def withErrorHandling(f: QueryExecutionListener => Unit): Unit = { for (listener <- listeners) { try { f(listener) } catch { - case e: Exception => logWarning("error executing query execution listener", e) + case NonFatal(e) => logWarning("Error executing query execution listener", e) } } } + + /** Acquires a read lock on the cache for the duration of `f`. */ + private def readLock[A](f: => A): A = { + val rl = lock.readLock() + rl.lock() + try f finally { + rl.unlock() + } + } + + /** Acquires a write lock on the cache for the duration of `f`. */ + private def writeLock[A](f: => A): A = { + val wl = lock.writeLock() + wl.lock() + try f finally { + wl.unlock() + } + } } From eec74ba8bde7f9446cc38e687bda103e85669d35 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 5 Nov 2015 19:02:18 -0800 Subject: [PATCH 09/88] [SPARK-7542][SQL] Support off-heap index/sort buffer This brings the support of off-heap memory for array inside BytesToBytesMap and InMemorySorter, then we could allocate all the memory from off-heap for execution. Closes #8068 Author: Davies Liu Closes #9477 from davies/unsafe_timsort. --- .../apache/spark/memory/MemoryConsumer.java | 36 +++++----- .../spark/memory/TaskMemoryManager.java | 6 +- .../shuffle/sort/ShuffleExternalSorter.java | 26 +++---- .../shuffle/sort/ShuffleInMemorySorter.java | 67 ++++++++++--------- .../shuffle/sort/ShuffleSortDataFormat.java | 38 +++++++---- .../spark/unsafe/map/BytesToBytesMap.java | 18 +++-- .../unsafe/sort/UnsafeExternalSorter.java | 28 +++----- .../unsafe/sort/UnsafeInMemorySorter.java | 66 +++++++++++------- .../unsafe/sort/UnsafeSortDataFormat.java | 47 +++++++------ .../spark/memory/TaskMemoryManagerSuite.java | 23 ------- .../spark/memory/TestMemoryConsumer.java | 45 +++++++++++++ .../sort/ShuffleInMemorySorterSuite.java | 16 +++-- .../sort/UnsafeExternalSorterSuite.java | 1 - .../sort/UnsafeInMemorySorterSuite.java | 12 ++-- .../sql/execution/UnsafeKVExternalSorter.java | 3 +- .../apache/spark/unsafe/array/LongArray.java | 18 ++++- .../spark/unsafe/array/LongArraySuite.java | 4 ++ 17 files changed, 265 insertions(+), 189 deletions(-) create mode 100644 core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 008799cc77395..8fbdb72832adf 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -20,6 +20,7 @@ import java.io.IOException; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -28,9 +29,9 @@ */ public abstract class MemoryConsumer { - private final TaskMemoryManager taskMemoryManager; + protected final TaskMemoryManager taskMemoryManager; private final long pageSize; - private long used; + protected long used; protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) { this.taskMemoryManager = taskMemoryManager; @@ -74,26 +75,29 @@ public void spill() throws IOException { public abstract long spill(long size, MemoryConsumer trigger) throws IOException; /** - * Acquire `size` bytes memory. - * - * If there is not enough memory, throws OutOfMemoryError. + * Allocates a LongArray of `size`. */ - protected void acquireMemory(long size) { - long got = taskMemoryManager.acquireExecutionMemory(size, this); - if (got < size) { - taskMemoryManager.releaseExecutionMemory(got, this); + public LongArray allocateArray(long size) { + long required = size * 8L; + MemoryBlock page = taskMemoryManager.allocatePage(required, this); + if (page == null || page.size() < required) { + long got = 0; + if (page != null) { + got = page.size(); + taskMemoryManager.freePage(page, this); + } taskMemoryManager.showMemoryUsage(); - throw new OutOfMemoryError("Could not acquire " + size + " bytes of memory, got " + got); + throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); } - used += got; + used += required; + return new LongArray(page); } /** - * Release `size` bytes memory. + * Frees a LongArray. */ - protected void releaseMemory(long size) { - used -= size; - taskMemoryManager.releaseExecutionMemory(size, this); + public void freeArray(LongArray array) { + freePage(array.memoryBlock()); } /** @@ -109,7 +113,7 @@ protected MemoryBlock allocatePage(long required) { long got = 0; if (page != null) { got = page.size(); - freePage(page); + taskMemoryManager.freePage(page, this); } taskMemoryManager.showMemoryUsage(); throw new OutOfMemoryError("Unable to acquire " + required + " bytes of memory, got " + got); diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 4230575446d31..6440f9c0f30de 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -137,7 +137,7 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { if (got < required) { // Call spill() on other consumers to release memory for (MemoryConsumer c: consumers) { - if (c != null && c != consumer && c.getUsed() > 0) { + if (c != consumer && c.getUsed() > 0) { try { long released = c.spill(required - got, consumer); if (released > 0) { @@ -173,7 +173,9 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { } } - consumers.add(consumer); + if (consumer != null) { + consumers.add(consumer); + } logger.debug("Task {} acquire {} for {}", taskAttemptId, Utils.bytesToString(got), consumer); return got; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 400d8520019b9..9affff80143d7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -39,6 +39,7 @@ import org.apache.spark.storage.DiskBlockObjectWriter; import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.Utils; @@ -114,8 +115,7 @@ public ShuffleExternalSorter( this.numElementsForSpillThreshold = conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE); this.writeMetrics = writeMetrics; - acquireMemory(initialSize * 8L); - this.inMemSorter = new ShuffleInMemorySorter(initialSize); + this.inMemSorter = new ShuffleInMemorySorter(this, initialSize); this.peakMemoryUsedBytes = getMemoryUsage(); } @@ -301,9 +301,8 @@ private long freeMemory() { public void cleanupResources() { freeMemory(); if (inMemSorter != null) { - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(sorterMemoryUsage); } for (SpillInfo spill : spills) { if (spill.file.exists() && !spill.file.delete()) { @@ -321,9 +320,10 @@ private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); - long needed = used + inMemSorter.getMemoryToExpand(); + LongArray array; try { - acquireMemory(needed); // could trigger spilling + // could trigger spilling + array = allocateArray(used / 8 * 2); } catch (OutOfMemoryError e) { // should have trigger spilling assert(inMemSorter.hasSpaceForAnotherRecord()); @@ -331,16 +331,9 @@ private void growPointerArrayIfNecessary() throws IOException { } // check if spilling is triggered or not if (inMemSorter.hasSpaceForAnotherRecord()) { - releaseMemory(needed); + freeArray(array); } else { - try { - inMemSorter.expandPointerArray(); - releaseMemory(used); - } catch (OutOfMemoryError oom) { - // Just in case that JVM had run out of memory - releaseMemory(needed); - spill(); - } + inMemSorter.expandPointerArray(array); } } } @@ -404,9 +397,8 @@ public SpillInfo[] closeAndGetSpills() throws IOException { // Do not count the final file towards the spill count. writeSortedFile(true); freeMemory(); - long sorterMemoryUsage = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(sorterMemoryUsage); } return spills.toArray(new SpillInfo[spills.size()]); } catch (IOException e) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index e630575d1ae19..58ad88e1ed87b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -19,11 +19,14 @@ import java.util.Comparator; +import org.apache.spark.memory.MemoryConsumer; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.util.collection.Sorter; final class ShuffleInMemorySorter { - private final Sorter sorter; + private final Sorter sorter; private static final class SortComparator implements Comparator { @Override public int compare(PackedRecordPointer left, PackedRecordPointer right) { @@ -32,24 +35,34 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { } private static final SortComparator SORT_COMPARATOR = new SortComparator(); + private final MemoryConsumer consumer; + /** * An array of record pointers and partition ids that have been encoded by * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating * records. */ - private long[] array; + private LongArray array; /** * The position in the pointer array where new records can be inserted. */ private int pos = 0; - public ShuffleInMemorySorter(int initialSize) { + public ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) { + this.consumer = consumer; assert (initialSize > 0); - this.array = new long[initialSize]; + this.array = consumer.allocateArray(initialSize); this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE); } + public void free() { + if (array != null) { + consumer.freeArray(array); + array = null; + } + } + public int numRecords() { return pos; } @@ -58,30 +71,25 @@ public void reset() { pos = 0; } - private int newLength() { - // Guard against overflow: - return array.length <= Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE; - } - - /** - * Returns the memory needed to expand - */ - public long getMemoryToExpand() { - return ((long) (newLength() - array.length)) * 8; - } - - public void expandPointerArray() { - final long[] oldArray = array; - array = new long[newLength()]; - System.arraycopy(oldArray, 0, array, 0, oldArray.length); + public void expandPointerArray(LongArray newArray) { + assert(newArray.size() > array.size()); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + array.size() * 8L + ); + consumer.freeArray(array); + array = newArray; } public boolean hasSpaceForAnotherRecord() { - return pos < array.length; + return pos < array.size(); } public long getMemoryUsage() { - return array.length * 8L; + return array.size() * 8L; } /** @@ -96,14 +104,9 @@ public long getMemoryUsage() { */ public void insertRecord(long recordPointer, int partitionId) { if (!hasSpaceForAnotherRecord()) { - if (array.length == Integer.MAX_VALUE) { - throw new IllegalStateException("Sort pointer array has reached maximum size"); - } else { - expandPointerArray(); - } + expandPointerArray(consumer.allocateArray(array.size() * 2)); } - array[pos] = - PackedRecordPointer.packPointer(recordPointer, partitionId); + array.set(pos, PackedRecordPointer.packPointer(recordPointer, partitionId)); pos++; } @@ -112,12 +115,12 @@ public void insertRecord(long recordPointer, int partitionId) { */ public static final class ShuffleSorterIterator { - private final long[] pointerArray; + private final LongArray pointerArray; private final int numRecords; final PackedRecordPointer packedRecordPointer = new PackedRecordPointer(); private int position = 0; - public ShuffleSorterIterator(int numRecords, long[] pointerArray) { + public ShuffleSorterIterator(int numRecords, LongArray pointerArray) { this.numRecords = numRecords; this.pointerArray = pointerArray; } @@ -127,7 +130,7 @@ public boolean hasNext() { } public void loadNext() { - packedRecordPointer.set(pointerArray[position]); + packedRecordPointer.set(pointerArray.get(position)); position++; } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java index 8a1e5aec6ff0e..8f4e3229976dc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java @@ -17,16 +17,19 @@ package org.apache.spark.shuffle.sort; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; -final class ShuffleSortDataFormat extends SortDataFormat { +final class ShuffleSortDataFormat extends SortDataFormat { public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat(); private ShuffleSortDataFormat() { } @Override - public PackedRecordPointer getKey(long[] data, int pos) { + public PackedRecordPointer getKey(LongArray data, int pos) { // Since we re-use keys, this method shouldn't be called. throw new UnsupportedOperationException(); } @@ -37,31 +40,38 @@ public PackedRecordPointer newKey() { } @Override - public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) { - reuse.set(data[pos]); + public PackedRecordPointer getKey(LongArray data, int pos, PackedRecordPointer reuse) { + reuse.set(data.get(pos)); return reuse; } @Override - public void swap(long[] data, int pos0, int pos1) { - final long temp = data[pos0]; - data[pos0] = data[pos1]; - data[pos1] = temp; + public void swap(LongArray data, int pos0, int pos1) { + final long temp = data.get(pos0); + data.set(pos0, data.get(pos1)); + data.set(pos1, temp); } @Override - public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { - dst[dstPos] = src[srcPos]; + public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { + dst.set(dstPos, src.get(srcPos)); } @Override - public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { - System.arraycopy(src, srcPos, dst, dstPos, length); + public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { + Platform.copyMemory( + src.getBaseObject(), + src.getBaseOffset() + srcPos * 8, + dst.getBaseObject(), + dst.getBaseOffset() + dstPos * 8, + length * 8 + ); } @Override - public long[] allocate(int length) { - return new long[length]; + public LongArray allocate(int length) { + // This buffer is used temporary (usually small), so it's fine to allocated from JVM heap. + return new LongArray(MemoryBlock.fromLongArray(new long[length])); } } diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 6656fd1d0bc59..04694dc54418c 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -20,7 +20,6 @@ import javax.annotation.Nullable; import java.io.File; import java.io.IOException; -import java.util.Arrays; import java.util.Iterator; import java.util.LinkedList; @@ -724,11 +723,10 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { */ private void allocate(int capacity) { assert (capacity >= 0); - // The capacity needs to be divisible by 64 so that our bit set can be sized properly capacity = Math.max((int) Math.min(MAX_CAPACITY, ByteArrayMethods.nextPowerOf2(capacity)), 64); assert (capacity <= MAX_CAPACITY); - acquireMemory(capacity * 16); - longArray = new LongArray(MemoryBlock.fromLongArray(new long[capacity * 2])); + longArray = allocateArray(capacity * 2); + longArray.zeroOut(); this.growthThreshold = (int) (capacity * loadFactor); this.mask = capacity - 1; @@ -743,9 +741,8 @@ private void allocate(int capacity) { public void free() { updatePeakMemoryUsed(); if (longArray != null) { - long used = longArray.memoryBlock().size(); + freeArray(longArray); longArray = null; - releaseMemory(used); } Iterator dataPagesIterator = dataPages.iterator(); while (dataPagesIterator.hasNext()) { @@ -834,9 +831,9 @@ public int getNumDataPages() { /** * Returns the underline long[] of longArray. */ - public long[] getArray() { + public LongArray getArray() { assert(longArray != null); - return (long[]) longArray.memoryBlock().getBaseObject(); + return longArray; } /** @@ -844,7 +841,8 @@ public long[] getArray() { */ public void reset() { numElements = 0; - Arrays.fill(getArray(), 0); + longArray.zeroOut(); + while (dataPages.size() > 0) { MemoryBlock dataPage = dataPages.removeLast(); freePage(dataPage); @@ -887,7 +885,7 @@ void growAndRehash() { longArray.set(newPos * 2, keyPointer); longArray.set(newPos * 2 + 1, hashcode); } - releaseMemory(oldLongArray.memoryBlock().size()); + freeArray(oldLongArray); if (enablePerfMetrics) { timeSpentResizingNs += System.nanoTime() - resizeStartTime; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index cba043bc48cc8..9a7b2ad06cab6 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -32,6 +32,7 @@ import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.TaskCompletionListener; import org.apache.spark.util.Utils; @@ -123,9 +124,8 @@ private UnsafeExternalSorter( this.writeMetrics = new ShuffleWriteMetrics(); if (existingInMemorySorter == null) { - this.inMemSorter = - new UnsafeInMemorySorter(taskMemoryManager, recordComparator, prefixComparator, initialSize); - acquireMemory(inMemSorter.getMemoryUsage()); + this.inMemSorter = new UnsafeInMemorySorter( + this, taskMemoryManager, recordComparator, prefixComparator, initialSize); } else { this.inMemSorter = existingInMemorySorter; } @@ -277,9 +277,8 @@ public void cleanupResources() { deleteSpillFiles(); freeMemory(); if (inMemSorter != null) { - long used = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(used); } } } @@ -293,9 +292,10 @@ private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { long used = inMemSorter.getMemoryUsage(); - long needed = used + inMemSorter.getMemoryToExpand(); + LongArray array; try { - acquireMemory(needed); // could trigger spilling + // could trigger spilling + array = allocateArray(used / 8 * 2); } catch (OutOfMemoryError e) { // should have trigger spilling assert(inMemSorter.hasSpaceForAnotherRecord()); @@ -303,16 +303,9 @@ private void growPointerArrayIfNecessary() throws IOException { } // check if spilling is triggered or not if (inMemSorter.hasSpaceForAnotherRecord()) { - releaseMemory(needed); + freeArray(array); } else { - try { - inMemSorter.expandPointerArray(); - releaseMemory(used); - } catch (OutOfMemoryError oom) { - // Just in case that JVM had run out of memory - releaseMemory(needed); - spill(); - } + inMemSorter.expandPointerArray(array); } } } @@ -498,9 +491,8 @@ public void loadNext() throws IOException { nextUpstream = null; assert(inMemSorter != null); - long used = inMemSorter.getMemoryUsage(); + inMemSorter.free(); inMemSorter = null; - releaseMemory(used); } numRecords--; upstream.loadNext(); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index d57213b9b8bfc..a218ad4623f46 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -19,8 +19,10 @@ import java.util.Comparator; +import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.util.collection.Sorter; /** @@ -62,15 +64,16 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { } } + private final MemoryConsumer consumer; private final TaskMemoryManager memoryManager; - private final Sorter sorter; + private final Sorter sorter; private final Comparator sortComparator; /** * Within this buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ - private long[] array; + private LongArray array; /** * The position in the sort buffer where new records can be inserted. @@ -78,22 +81,33 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { private int pos = 0; public UnsafeInMemorySorter( + final MemoryConsumer consumer, final TaskMemoryManager memoryManager, final RecordComparator recordComparator, final PrefixComparator prefixComparator, int initialSize) { - this(memoryManager, recordComparator, prefixComparator, new long[initialSize * 2]); + this(consumer, memoryManager, recordComparator, prefixComparator, + consumer.allocateArray(initialSize * 2)); } public UnsafeInMemorySorter( + final MemoryConsumer consumer, final TaskMemoryManager memoryManager, final RecordComparator recordComparator, final PrefixComparator prefixComparator, - long[] array) { - this.array = array; + LongArray array) { + this.consumer = consumer; this.memoryManager = memoryManager; this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + this.array = array; + } + + /** + * Free the memory used by pointer array. + */ + public void free() { + consumer.freeArray(array); } public void reset() { @@ -107,26 +121,26 @@ public int numRecords() { return pos / 2; } - private int newLength() { - return array.length < Integer.MAX_VALUE / 2 ? (array.length * 2) : Integer.MAX_VALUE; - } - - public long getMemoryToExpand() { - return (long) (newLength() - array.length) * 8L; - } - public long getMemoryUsage() { - return array.length * 8L; + return array.size() * 8L; } public boolean hasSpaceForAnotherRecord() { - return pos + 2 <= array.length; + return pos + 2 <= array.size(); } - public void expandPointerArray() { - final long[] oldArray = array; - array = new long[newLength()]; - System.arraycopy(oldArray, 0, array, 0, oldArray.length); + public void expandPointerArray(LongArray newArray) { + if (newArray.size() < array.size()) { + throw new OutOfMemoryError("Not enough memory to grow pointer array"); + } + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + array.size() * 8L); + consumer.freeArray(array); + array = newArray; } /** @@ -138,11 +152,11 @@ public void expandPointerArray() { */ public void insertRecord(long recordPointer, long keyPrefix) { if (!hasSpaceForAnotherRecord()) { - expandPointerArray(); + expandPointerArray(consumer.allocateArray(array.size() * 2)); } - array[pos] = recordPointer; + array.set(pos, recordPointer); pos++; - array[pos] = keyPrefix; + array.set(pos, keyPrefix); pos++; } @@ -150,7 +164,7 @@ public static final class SortedIterator extends UnsafeSorterIterator { private final TaskMemoryManager memoryManager; private final int sortBufferInsertPosition; - private final long[] sortBuffer; + private final LongArray sortBuffer; private int position = 0; private Object baseObject; private long baseOffset; @@ -160,7 +174,7 @@ public static final class SortedIterator extends UnsafeSorterIterator { private SortedIterator( TaskMemoryManager memoryManager, int sortBufferInsertPosition, - long[] sortBuffer) { + LongArray sortBuffer) { this.memoryManager = memoryManager; this.sortBufferInsertPosition = sortBufferInsertPosition; this.sortBuffer = sortBuffer; @@ -188,11 +202,11 @@ public int numRecordsLeft() { @Override public void loadNext() { // This pointer points to a 4-byte record length, followed by the record's bytes - final long recordPointer = sortBuffer[position]; + final long recordPointer = sortBuffer.get(position); baseObject = memoryManager.getPage(recordPointer); baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length recordLength = Platform.getInt(baseObject, baseOffset - 4); - keyPrefix = sortBuffer[position + 1]; + keyPrefix = sortBuffer.get(position + 1); position += 2; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java index d09c728a7a638..d3137f5f31c25 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSortDataFormat.java @@ -17,6 +17,9 @@ package org.apache.spark.util.collection.unsafe.sort; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.collection.SortDataFormat; /** @@ -26,14 +29,14 @@ * Within each long[] buffer, position {@code 2 * i} holds a pointer pointer to the record at * index {@code i}, while position {@code 2 * i + 1} in the array holds an 8-byte key prefix. */ -final class UnsafeSortDataFormat extends SortDataFormat { +final class UnsafeSortDataFormat extends SortDataFormat { public static final UnsafeSortDataFormat INSTANCE = new UnsafeSortDataFormat(); private UnsafeSortDataFormat() { } @Override - public RecordPointerAndKeyPrefix getKey(long[] data, int pos) { + public RecordPointerAndKeyPrefix getKey(LongArray data, int pos) { // Since we re-use keys, this method shouldn't be called. throw new UnsupportedOperationException(); } @@ -44,37 +47,43 @@ public RecordPointerAndKeyPrefix newKey() { } @Override - public RecordPointerAndKeyPrefix getKey(long[] data, int pos, RecordPointerAndKeyPrefix reuse) { - reuse.recordPointer = data[pos * 2]; - reuse.keyPrefix = data[pos * 2 + 1]; + public RecordPointerAndKeyPrefix getKey(LongArray data, int pos, RecordPointerAndKeyPrefix reuse) { + reuse.recordPointer = data.get(pos * 2); + reuse.keyPrefix = data.get(pos * 2 + 1); return reuse; } @Override - public void swap(long[] data, int pos0, int pos1) { - long tempPointer = data[pos0 * 2]; - long tempKeyPrefix = data[pos0 * 2 + 1]; - data[pos0 * 2] = data[pos1 * 2]; - data[pos0 * 2 + 1] = data[pos1 * 2 + 1]; - data[pos1 * 2] = tempPointer; - data[pos1 * 2 + 1] = tempKeyPrefix; + public void swap(LongArray data, int pos0, int pos1) { + long tempPointer = data.get(pos0 * 2); + long tempKeyPrefix = data.get(pos0 * 2 + 1); + data.set(pos0 * 2, data.get(pos1 * 2)); + data.set(pos0 * 2 + 1, data.get(pos1 * 2 + 1)); + data.set(pos1 * 2, tempPointer); + data.set(pos1 * 2 + 1, tempKeyPrefix); } @Override - public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) { - dst[dstPos * 2] = src[srcPos * 2]; - dst[dstPos * 2 + 1] = src[srcPos * 2 + 1]; + public void copyElement(LongArray src, int srcPos, LongArray dst, int dstPos) { + dst.set(dstPos * 2, src.get(srcPos * 2)); + dst.set(dstPos * 2 + 1, src.get(srcPos * 2 + 1)); } @Override - public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) { - System.arraycopy(src, srcPos * 2, dst, dstPos * 2, length * 2); + public void copyRange(LongArray src, int srcPos, LongArray dst, int dstPos, int length) { + Platform.copyMemory( + src.getBaseObject(), + src.getBaseOffset() + srcPos * 16, + dst.getBaseObject(), + dst.getBaseOffset() + dstPos * 16, + length * 16); } @Override - public long[] allocate(int length) { + public LongArray allocate(int length) { assert (length < Integer.MAX_VALUE / 2) : "Length " + length + " is too large"; - return new long[length * 2]; + // This is used as temporary buffer, it's fine to allocate from JVM heap. + return new LongArray(MemoryBlock.fromLongArray(new long[length * 2])); } } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index dab7b0592cb4e..c731317395612 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -17,8 +17,6 @@ package org.apache.spark.memory; -import java.io.IOException; - import org.junit.Assert; import org.junit.Test; @@ -27,27 +25,6 @@ public class TaskMemoryManagerSuite { - class TestMemoryConsumer extends MemoryConsumer { - TestMemoryConsumer(TaskMemoryManager memoryManager) { - super(memoryManager); - } - - @Override - public long spill(long size, MemoryConsumer trigger) throws IOException { - long used = getUsed(); - releaseMemory(used); - return used; - } - - void use(long size) { - acquireMemory(size); - } - - void free(long size) { - releaseMemory(size); - } - } - @Test public void leakedPageMemoryIsDetected() { final TaskMemoryManager manager = new TaskMemoryManager( diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java new file mode 100644 index 0000000000000..8ae3642738509 --- /dev/null +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + +import java.io.IOException; + +public class TestMemoryConsumer extends MemoryConsumer { + public TestMemoryConsumer(TaskMemoryManager memoryManager) { + super(memoryManager); + } + + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + long used = getUsed(); + free(used); + return used; + } + + void use(long size) { + long got = taskMemoryManager.acquireExecutionMemory(size, this); + used += got; + } + + void free(long size) { + used -= size; + taskMemoryManager.releaseExecutionMemory(size, this); + } +} + + diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index 2293b1bbc113e..faa5a863ee630 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -25,13 +25,19 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; -import org.apache.spark.unsafe.Platform; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.memory.MemoryBlock; -import org.apache.spark.memory.TaskMemoryManager; public class ShuffleInMemorySorterSuite { + final TestMemoryManager memoryManager = + new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")); + final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(taskMemoryManager); + private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) { final byte[] strBytes = new byte[strLength]; Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength); @@ -40,7 +46,7 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset, @Test public void testSortingEmptyInput() { - final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 100); final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator(); assert(!iter.hasNext()); } @@ -63,7 +69,7 @@ public void testBasicSorting() throws Exception { new TaskMemoryManager(new TestMemoryManager(conf), 0); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); - final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4); final HashPartitioner hashPartitioner = new HashPartitioner(4); // Write the records into the data page and store pointers into the sorter @@ -104,7 +110,7 @@ public void testBasicSorting() throws Exception { @Test public void testSortingManyNumbers() throws Exception { - ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4); + ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(consumer, 4); int[] numbersToSort = new int[128000]; Random random = new Random(16); for (int i = 0; i < numbersToSort.length; i++) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index cfead0e5924b8..11c3a7be38875 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -390,7 +390,6 @@ public void testPeakMemoryUsed() throws Exception { for (int i = 0; i < numRecordsPerPage * 10; i++) { insertNumber(sorter, i); newPeakMemory = sorter.getPeakMemoryUsedBytes(); - // The first page is pre-allocated on instantiation if (i % numRecordsPerPage == 0) { // We allocated a new page for this record, so peak memory should change assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory); diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 642f6585f8a15..a203a09648ac0 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -23,6 +23,7 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; +import org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; @@ -44,9 +45,11 @@ private static String getStringFromDataPage(Object baseObject, long baseOffset, @Test public void testSortingEmptyInput() { - final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter( - new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0), + final TaskMemoryManager memoryManager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + final UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, + memoryManager, mock(RecordComparator.class), mock(PrefixComparator.class), 100); @@ -69,6 +72,7 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { }; final TaskMemoryManager memoryManager = new TaskMemoryManager( new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); final MemoryBlock dataPage = memoryManager.allocatePage(2048, null); final Object baseObject = dataPage.getBaseObject(); // Write the records into the data page: @@ -102,7 +106,7 @@ public int compare(long prefix1, long prefix2) { return (int) prefix1 - (int) prefix2; } }; - UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(memoryManager, recordComparator, + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, recordComparator, prefixComparator, dataToSort.length); // Given a page of records, insert those records into the sorter one-by-one: position = dataPage.getBaseOffset(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index e2898ef2e2158..8c9b9c85e37fc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -85,8 +85,9 @@ public UnsafeKVExternalSorter( } else { // During spilling, the array in map will not be used, so we can borrow that and use it // as the underline array for in-memory sorter (it's always large enough). + // Since we will not grow the array, it's fine to pass `null` as consumer. final UnsafeInMemorySorter inMemSorter = new UnsafeInMemorySorter( - taskMemoryManager, recordComparator, prefixComparator, map.getArray()); + null, taskMemoryManager, recordComparator, prefixComparator, map.getArray()); // We cannot use the destructive iterator here because we are reusing the existing memory // pages in BytesToBytesMap to hold records during sorting. diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java index 74105050e4191..1a3cdff638264 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/array/LongArray.java @@ -39,7 +39,6 @@ public final class LongArray { private final long length; public LongArray(MemoryBlock memory) { - assert memory.size() % WIDTH == 0 : "Memory not aligned (" + memory.size() + ")"; assert memory.size() < (long) Integer.MAX_VALUE * 8: "Array size > 4 billion elements"; this.memory = memory; this.baseObj = memory.getBaseObject(); @@ -51,6 +50,14 @@ public MemoryBlock memoryBlock() { return memory; } + public Object getBaseObject() { + return baseObj; + } + + public long getBaseOffset() { + return baseOffset; + } + /** * Returns the number of elements this array can hold. */ @@ -58,6 +65,15 @@ public long size() { return length; } + /** + * Fill this all with 0L. + */ + public void zeroOut() { + for (long off = baseOffset; off < baseOffset + length * WIDTH; off += WIDTH) { + Platform.putLong(baseObj, off, 0); + } + } + /** * Sets the value at position {@code index}. */ diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java index 5974cf91ff993..fb8e53b3348f3 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/array/LongArraySuite.java @@ -34,5 +34,9 @@ public void basicTest() { Assert.assertEquals(2, arr.size()); Assert.assertEquals(1L, arr.get(0)); Assert.assertEquals(3L, arr.get(1)); + + arr.zeroOut(); + Assert.assertEquals(0L, arr.get(0)); + Assert.assertEquals(0L, arr.get(1)); } } From 363a476c3fefb0263e63fd24df0b2779a64f79ec Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 5 Nov 2015 21:42:32 -0800 Subject: [PATCH 10/88] [SPARK-11528] [SQL] Typed aggregations for Datasets This PR adds the ability to do typed SQL aggregations. We will likely also want to provide an interface to allow users to do aggregations on objects, but this is deferred to another PR. ```scala val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() ds.groupBy(_._1).agg(sum("_2").as[Int]).collect() res0: Array(("a", 30), ("b", 3), ("c", 1)) ``` Author: Michael Armbrust Closes #9499 from marmbrus/dataset-agg. --- .../expressions/namedExpressions.scala | 4 + .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../org/apache/spark/sql/GroupedDataset.scala | 93 ++++++++++++++++++- .../org/apache/spark/sql/DatasetSuite.scala | 36 +++++++ 4 files changed, 132 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 8957df0be6814..9ab5c299d0f55 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -254,6 +254,10 @@ case class AttributeReference( } override def toString: String = s"$name#${exprId.id}$typeSuffix" + + // Since the expression id is not in the first constructor it is missing from the default + // tree string. + override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 500227e93a472..4bca9c3b3fe54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -55,7 +55,7 @@ import org.apache.spark.sql.types.StructType * @since 1.6.0 */ @Experimental -class Dataset[T] private( +class Dataset[T] private[sql]( @transient val sqlContext: SQLContext, @transient val queryExecution: QueryExecution, unresolvedEncoder: Encoder[T]) extends Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 96d6e9dd548e5..b8fc373dffcf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,16 +17,25 @@ package org.apache.spark.sql +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution /** + * :: Experimental :: * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing * [[Dataset]]. + * + * COMPATIBILITY NOTE: Long term we plan to make [[GroupedDataset)]] extend `GroupedData`. However, + * making this change to the class hierarchy would break some function signatures. As such, this + * class should be considered a preview of the final API. Changes will be made to the interface + * after Spark 1.6. */ +@Experimental class GroupedDataset[K, T] private[sql]( private val kEncoder: Encoder[K], private val tEncoder: Encoder[T], @@ -35,7 +44,7 @@ class GroupedDataset[K, T] private[sql]( private val groupingAttributes: Seq[Attribute]) extends Serializable { private implicit val kEnc = kEncoder match { - case e: ExpressionEncoder[K] => e.resolve(groupingAttributes) + case e: ExpressionEncoder[K] => e.unbind(groupingAttributes).resolve(groupingAttributes) case other => throw new UnsupportedOperationException("Only expression encoders are currently supported") } @@ -46,9 +55,16 @@ class GroupedDataset[K, T] private[sql]( throw new UnsupportedOperationException("Only expression encoders are currently supported") } + /** Encoders for built in aggregations. */ + private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) + private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext + private def groupedData = + new GroupedData( + new DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType) + /** * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]]. @@ -88,6 +104,79 @@ class GroupedDataset[K, T] private[sql]( MapGroups(f, groupingAttributes, logicalPlan)) } + // To ensure valid overloading. + protected def agg(expr: Column, exprs: Column*): DataFrame = + groupedData.agg(expr, exprs: _*) + + /** + * Internal helper function for building typed aggregations that return tuples. For simplicity + * and code reuse, we do this without the help of the type system and then use helper functions + * that cast appropriately for the user facing interface. + * TODO: does not handle aggrecations that return nonflat results, + */ + protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = { + val aliases = (groupingAttributes ++ columns.map(_.expr)).map { + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + + val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan) + val execution = new QueryExecution(sqlContext, unresolvedPlan) + + val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]) + + // Rebind the encoders to the nested schema that will be produced by the aggregation. + val encoders = (kEnc +: columnEncoders).zip(execution.analyzed.output).map { + case (e: ExpressionEncoder[_], a) if !e.flat => + e.nested(a).resolve(execution.analyzed.output) + case (e, a) => + e.unbind(a :: Nil).resolve(execution.analyzed.output) + } + new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) + } + + /** + * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key + * and the result of computing this aggregation over all elements in the group. + */ + def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] = + aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + */ + def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, A1, A2)] = + aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + */ + def agg[A1, A2, A3]( + col1: TypedColumn[A1], + col2: TypedColumn[A2], + col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] = + aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]] + + /** + * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key + * and the result of computing these aggregations over all elements in the group. + */ + def agg[A1, A2, A3, A4]( + col1: TypedColumn[A1], + col2: TypedColumn[A2], + col3: TypedColumn[A3], + col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] = + aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, A4)]] + + /** + * Returns a [[Dataset]] that contains a tuple with each key and the number of items present + * for that key. + */ + def count(): Dataset[(K, Long)] = agg(functions.count("*").as[Long]) + /** * Applies the given function to each cogrouped data. For each unique group, the function will * be passed the grouping key and 2 iterators containing all elements in the group from diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 3e9b621cfd67f..d61e17edc64ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -258,6 +258,42 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1)) } + test("typed aggregation: expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg(sum("_2").as[Int]), + ("a", 30), ("b", 3), ("c", 1)) + } + + test("typed aggregation: expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long]), + ("a", 30, 32L), ("b", 3, 5L), ("c", 1, 2L)) + } + + test("typed aggregation: expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], count("*").as[Long]), + ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L)) + } + + test("typed aggregation: expr, expr, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + sum("_2").as[Int], + sum($"_2" + 1).as[Long], + count("*").as[Long], + avg("_2").as[Double]), + ("a", 30, 32L, 2L, 15.0), ("b", 3, 5L, 2L, 1.5), ("c", 1, 2L, 1L, 1.0)) + } + test("cogroup") { val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS() val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS() From bc5d6c03893a9bd340d6b94d3550e25648412241 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 5 Nov 2015 22:03:26 -0800 Subject: [PATCH 11/88] [SPARK-11541][SQL] Break JdbcDialects.scala into multiple files and mark various dialects as private. Author: Reynold Xin Closes #9511 from rxin/SPARK-11541. --- project/MimaExcludes.scala | 19 +- .../org/apache/spark/sql/GroupedData.scala | 2 +- .../spark/sql/jdbc/AggregatedDialect.scala | 44 ++++ .../apache/spark/sql/jdbc/DB2Dialect.scala | 32 +++ .../apache/spark/sql/jdbc/DerbyDialect.scala | 44 ++++ .../apache/spark/sql/jdbc/JdbcDialects.scala | 190 +----------------- .../spark/sql/jdbc/MsSqlServerDialect.scala | 41 ++++ .../apache/spark/sql/jdbc/MySQLDialect.scala | 48 +++++ .../apache/spark/sql/jdbc/OracleDialect.scala | 45 +++++ .../spark/sql/jdbc/PostgresDialect.scala | 54 +++++ 10 files changed, 332 insertions(+), 187 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 40f5c9fec8bb8..dacef911e397e 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -116,7 +116,24 @@ object MimaExcludes { "org.apache.spark.rdd.MapPartitionsWithPreparationRDD$") ) ++ Seq( // SPARK-11485 - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df") + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.DataFrameHolder.df"), + // SPARK-11541 mark various JDBC dialects as private + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productElement"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productArity"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.canEqual"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productIterator"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.productPrefix"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.toString"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.NoopDialect.hashCode"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.PostgresDialect$"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productElement"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productArity"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.canEqual"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productIterator"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.productPrefix"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.toString"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.hashCode"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$") ) case v if v.startsWith("1.5") => Seq( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 7cf66b65c8722..f9eab5c2e965b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.types.NumericType class GroupedData protected[sql]( df: DataFrame, groupingExprs: Seq[Expression], - private val groupType: GroupedData.GroupType) { + groupType: GroupedData.GroupType) { private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala new file mode 100644 index 0000000000000..467d8d62d1b7f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/AggregatedDialect.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types.{DataType, MetadataBuilder} + +/** + * AggregatedDialect can unify multiple dialects into one virtual Dialect. + * Dialects are tried in order, and the first dialect that does not return a + * neutral element will will. + * + * @param dialects List of dialects. + */ +private class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { + + require(dialects.nonEmpty) + + override def canHandle(url : String): Boolean = + dialects.map(_.canHandle(url)).reduce(_ && _) + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = { + dialects.flatMap(_.getJDBCType(dt)).headOption + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala new file mode 100644 index 0000000000000..b1cb0e55026be --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types.{BooleanType, StringType, DataType} + + +private object DB2Dialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) + case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala new file mode 100644 index 0000000000000..84f68e779c38c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ + + +private object DerbyDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.REAL) Option(FloatType) else None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Option(JdbcType("CLOB", java.sql.Types.CLOB)) + case ByteType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) + case ShortType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) + case BooleanType => Option(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + // 31 is the maximum precision and 5 is the default scale for a Derby DECIMAL + case t: DecimalType if t.precision > 31 => + Option(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index f9a6a09b6270d..14bfea4e3e287 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.jdbc -import java.sql.Types - import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi @@ -115,11 +113,10 @@ abstract class JdbcDialect { @DeveloperApi object JdbcDialects { - private var dialects = List[JdbcDialect]() - /** * Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]]. * Readding an existing dialect will cause a move-to-front. + * * @param dialect The new dialect. */ def registerDialect(dialect: JdbcDialect) : Unit = { @@ -128,12 +125,15 @@ object JdbcDialects { /** * Unregister a dialect. Does nothing if the dialect is not registered. + * * @param dialect The jdbc dialect. */ def unregisterDialect(dialect : JdbcDialect) : Unit = { dialects = dialects.filterNot(_ == dialect) } + private[this] var dialects = List[JdbcDialect]() + registerDialect(MySQLDialect) registerDialect(PostgresDialect) registerDialect(DB2Dialect) @@ -141,7 +141,6 @@ object JdbcDialects { registerDialect(DerbyDialect) registerDialect(OracleDialect) - /** * Fetch the JdbcDialect class corresponding to a given database url. */ @@ -156,187 +155,8 @@ object JdbcDialects { } /** - * :: DeveloperApi :: - * AggregatedDialect can unify multiple dialects into one virtual Dialect. - * Dialects are tried in order, and the first dialect that does not return a - * neutral element will will. - * @param dialects List of dialects. - */ -@DeveloperApi -class AggregatedDialect(dialects: List[JdbcDialect]) extends JdbcDialect { - - require(dialects.nonEmpty) - - override def canHandle(url : String): Boolean = - dialects.map(_.canHandle(url)).reduce(_ && _) - - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - dialects.flatMap(_.getCatalystType(sqlType, typeName, size, md)).headOption - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = { - dialects.flatMap(_.getJDBCType(dt)).headOption - } -} - -/** - * :: DeveloperApi :: * NOOP dialect object, always returning the neutral element. */ -@DeveloperApi -case object NoopDialect extends JdbcDialect { +private object NoopDialect extends JdbcDialect { override def canHandle(url : String): Boolean = true } - -/** - * :: DeveloperApi :: - * Default postgres dialect, mapping bit/cidr/inet on read and string/binary/boolean on write. - */ -@DeveloperApi -case object PostgresDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { - Option(BinaryType) - } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("inet")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("json")) { - Option(StringType) - } else if (sqlType == Types.OTHER && typeName.equals("jsonb")) { - Option(StringType) - } else None - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR)) - case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY)) - case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) - case _ => None - } - - override def getTableExistsQuery(table: String): String = { - s"SELECT 1 FROM $table LIMIT 1" - } - -} - -/** - * :: DeveloperApi :: - * Default mysql dialect to read bit/bitsets correctly. - */ -@DeveloperApi -case object MySQLDialect extends JdbcDialect { - override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { - // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as - // byte arrays instead of longs. - md.putLong("binarylong", 1) - Option(LongType) - } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { - Option(BooleanType) - } else None - } - - override def quoteIdentifier(colName: String): String = { - s"`$colName`" - } - - override def getTableExistsQuery(table: String): String = { - s"SELECT 1 FROM $table LIMIT 1" - } -} - -/** - * :: DeveloperApi :: - * Default DB2 dialect, mapping string/boolean on write to valid DB2 types. - * By default string, and boolean gets mapped to db2 invalid types TEXT, and BIT(1). - */ -@DeveloperApi -case object DB2Dialect extends JdbcDialect { - - override def canHandle(url: String): Boolean = url.startsWith("jdbc:db2") - - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case StringType => Some(JdbcType("CLOB", java.sql.Types.CLOB)) - case BooleanType => Some(JdbcType("CHAR(1)", java.sql.Types.CHAR)) - case _ => None - } -} - -/** - * :: DeveloperApi :: - * Default Microsoft SQL Server dialect, mapping the datetimeoffset types to a String on read. - */ -@DeveloperApi -case object MsSqlServerDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (typeName.contains("datetimeoffset")) { - // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients - Option(StringType) - } else None - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case TimestampType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP)) - case _ => None - } -} - -/** - * :: DeveloperApi :: - * Default Apache Derby dialect, mapping real on read - * and string/byte/short/boolean/decimal on write. - */ -@DeveloperApi -case object DerbyDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:derby") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.REAL) Option(FloatType) else None - } - - override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { - case StringType => Some(JdbcType("CLOB", java.sql.Types.CLOB)) - case ByteType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) - case ShortType => Some(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) - case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) - // 31 is the maximum precision and 5 is the default scale for a Derby DECIMAL - case (t: DecimalType) if (t.precision > 31) => - Some(JdbcType("DECIMAL(31,5)", java.sql.Types.DECIMAL)) - case _ => None - } - -} - -/** - * :: DeveloperApi :: - * Default Oracle dialect, mapping a nonspecific numeric type to a general decimal type. - */ -@DeveloperApi -case object OracleDialect extends JdbcDialect { - override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") - override def getCatalystType( - sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - // Handle NUMBER fields that have no precision/scale in special way - // because JDBC ResultSetMetaData converts this to 0 procision and -127 scale - // For more details, please see - // https://github.com/apache/spark/pull/8780#issuecomment-145598968 - // and - // https://github.com/apache/spark/pull/8780#issuecomment-144541760 - if (sqlType == Types.NUMERIC && size == 0) { - // This is sub-optimal as we have to pick a precision/scale in advance whereas the data - // in Oracle is allowed to have different precision/scale for each value. - Some(DecimalType(DecimalType.MAX_PRECISION, 10)) - } else { - None - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala new file mode 100644 index 0000000000000..3eb722b070d5d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import org.apache.spark.sql.types._ + + +private object MsSqlServerDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:sqlserver") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (typeName.contains("datetimeoffset")) { + // String is recommend by Microsoft SQL Server for datetimeoffset types in non-MS clients + Option(StringType) + } else { + None + } + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case TimestampType => Some(JdbcType("DATETIME", java.sql.Types.TIMESTAMP)) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala new file mode 100644 index 0000000000000..da413ed1f08b5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types.{BooleanType, LongType, DataType, MetadataBuilder} + + +private case object MySQLDialect extends JdbcDialect { + + override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { + // This could instead be a BinaryType if we'd rather return bit-vectors of up to 64 bits as + // byte arrays instead of longs. + md.putLong("binarylong", 1) + Option(LongType) + } else if (sqlType == Types.BIT && typeName.equals("TINYINT")) { + Option(BooleanType) + } else None + } + + override def quoteIdentifier(colName: String): String = { + s"`$colName`" + } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala new file mode 100644 index 0000000000000..4165c382689f9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ + + +private case object OracleDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + // Handle NUMBER fields that have no precision/scale in special way + // because JDBC ResultSetMetaData converts this to 0 procision and -127 scale + // For more details, please see + // https://github.com/apache/spark/pull/8780#issuecomment-145598968 + // and + // https://github.com/apache/spark/pull/8780#issuecomment-144541760 + if (sqlType == Types.NUMERIC && size == 0) { + // This is sub-optimal as we have to pick a precision/scale in advance whereas the data + // in Oracle is allowed to have different precision/scale for each value. + Option(DecimalType(DecimalType.MAX_PRECISION, 10)) + } else { + None + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala new file mode 100644 index 0000000000000..e701a7fcd9e16 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Types + +import org.apache.spark.sql.types._ + + +private object PostgresDialect extends JdbcDialect { + + override def canHandle(url: String): Boolean = url.startsWith("jdbc:postgresql") + + override def getCatalystType( + sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { + if (sqlType == Types.BIT && typeName.equals("bit") && size != 1) { + Option(BinaryType) + } else if (sqlType == Types.OTHER && typeName.equals("cidr")) { + Option(StringType) + } else if (sqlType == Types.OTHER && typeName.equals("inet")) { + Option(StringType) + } else if (sqlType == Types.OTHER && typeName.equals("json")) { + Option(StringType) + } else if (sqlType == Types.OTHER && typeName.equals("jsonb")) { + Option(StringType) + } else None + } + + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { + case StringType => Some(JdbcType("TEXT", java.sql.Types.CHAR)) + case BinaryType => Some(JdbcType("BYTEA", java.sql.Types.BINARY)) + case BooleanType => Some(JdbcType("BOOLEAN", java.sql.Types.BOOLEAN)) + case _ => None + } + + override def getTableExistsQuery(table: String): String = { + s"SELECT 1 FROM $table LIMIT 1" + } +} From 253e87e8ab8717ffef40a6d0d376b1add155ef90 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 6 Nov 2015 06:38:49 -0800 Subject: [PATCH 12/88] [SPARK-11453][SQL][FOLLOW-UP] remove DecimalLit A cleanup for https://github.com/apache/spark/pull/9085. The `DecimalLit` is very similar to `FloatLit`, we can just keep one of them. Also added low level unit test at `SqlParserSuite` Author: Wenchen Fan Closes #9482 from cloud-fan/parser. --- .../sql/catalyst/AbstractSparkSQLParser.scala | 23 ++++++++----------- .../apache/spark/sql/catalyst/SqlParser.scala | 20 ++++------------ .../spark/sql/catalyst/SqlParserSuite.scala | 21 +++++++++++++++++ 3 files changed, 35 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index 04ac4f20c66ec..bdc52c08acb66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -78,10 +78,6 @@ private[sql] abstract class AbstractSparkSQLParser } class SqlLexical extends StdLexical { - case class FloatLit(chars: String) extends Token { - override def toString: String = chars - } - case class DecimalLit(chars: String) extends Token { override def toString: String = chars } @@ -106,17 +102,16 @@ class SqlLexical extends StdLexical { } override lazy val token: Parser[Token] = - ( rep1(digit) ~ ('.' ~> digit.*).? ~ (exp ~> sign.? ~ rep1(digit)) ^^ { - case i ~ None ~ (sig ~ rest) => - DecimalLit(i.mkString + "e" + sig.mkString + rest.mkString) - case i ~ Some(d) ~ (sig ~ rest) => - DecimalLit(i.mkString + "." + d.mkString + "e" + sig.mkString + rest.mkString) - } + ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) } + | '.' ~> (rep1(digit) ~ scientificNotation) ^^ + { case i ~ s => DecimalLit("0." + i.mkString + s) } + | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^ + { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." + i2.mkString + s) } | digit.* ~ identChar ~ (identChar | digit).* ^^ { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) } | rep1(digit) ~ ('.' ~> digit.*).? ^^ { case i ~ None => NumericLit(i.mkString) - case i ~ Some(d) => FloatLit(i.mkString + "." + d.mkString) + case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString) } | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ { case chars => StringLit(chars mkString "") } @@ -133,8 +128,10 @@ class SqlLexical extends StdLexical { override def identChar: Parser[Elem] = letter | elem('_') - private lazy val sign: Parser[Elem] = elem("s", c => c == '+' || c == '-') - private lazy val exp: Parser[Elem] = elem("e", c => c == 'E' || c == 'e') + private lazy val scientificNotation: Parser[String] = + (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ { + case s ~ rest => "e" + s.mkString + rest.mkString + } override def whitespace: Parser[Any] = ( whitespaceChar diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 440e9e28fa783..cd717c09f8e5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -334,27 +334,15 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val numericLiteral: Parser[Literal] = ( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) } - | sign.? ~ unsignedFloat ^^ { - case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) - } - | sign.? ~ unsignedDecimal ^^ { - case s ~ d => Literal(toDecimalOrDouble(s.getOrElse("") + d)) - } + | sign.? ~ unsignedFloat ^^ + { case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) } ) protected lazy val unsignedFloat: Parser[String] = ( "." ~> numericLit ^^ { u => "0." + u } - | elem("decimal", _.isInstanceOf[lexical.FloatLit]) ^^ (_.chars) + | elem("decimal", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars) ) - protected lazy val unsignedDecimal: Parser[String] = - ( "." ~> decimalLit ^^ { u => "0." + u } - | elem("scientific_notation", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars) - ) - - def decimalLit: Parser[String] = - elem("scientific_notation", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars) - protected lazy val sign: Parser[String] = ("+" | "-") protected lazy val integral: Parser[String] = @@ -477,7 +465,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val baseExpression: Parser[Expression] = ( "*" ^^^ UnresolvedStar(None) - | (ident <~ "."). + <~ "*" ^^ { case target => UnresolvedStar(Option(target))} + | rep1(ident <~ ".") <~ "*" ^^ { case target => UnresolvedStar(Option(target))} | primary ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index ea28bfa021bed..9ff893b84775b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -126,4 +126,25 @@ class SqlParserSuite extends PlanTest { checkSingleUnit("13.123456789", "second") checkSingleUnit("-13.123456789", "second") } + + test("support scientific notation") { + def assertRight(input: String, output: Double): Unit = { + val parsed = SqlParser.parse("SELECT " + input) + val expected = Project( + UnresolvedAlias( + Literal(output) + ) :: Nil, + OneRowRelation) + comparePlans(parsed, expected) + } + + assertRight("9.0e1", 90) + assertRight(".9e+2", 90) + assertRight("0.9e+2", 90) + assertRight("900e-1", 90) + assertRight("900.0E-1", 90) + assertRight("9.e+1", 90) + + intercept[RuntimeException](SqlParser.parse("SELECT .e3")) + } } From cf69ce136590fea51843bc54f44f0f45c7d0ac36 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 6 Nov 2015 14:51:53 +0000 Subject: [PATCH 13/88] [SPARK-11511][STREAMING] Fix NPE when an InputDStream is not used Just ignored `InputDStream`s that have null `rememberDuration` in `DStreamGraph.getMaxInputStreamRememberDuration`. Author: Shixiong Zhu Closes #9476 from zsxwing/SPARK-11511. --- .../apache/spark/streaming/DStreamGraph.scala | 3 ++- .../spark/streaming/StreamingContextSuite.scala | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 1b0b7890b3b00..7829f5e887995 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -167,7 +167,8 @@ final private[streaming] class DStreamGraph extends Serializable with Logging { * safe remember duration which can be used to perform cleanup operations. */ def getMaxInputStreamRememberDuration(): Duration = { - inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds } + // If an InputDStream is not used, its `rememberDuration` will be null and we can ignore them + inputStreams.map(_.rememberDuration).filter(_ != null).maxBy(_.milliseconds) } @throws(classOf[IOException]) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index c7a877142b374..860fac29c0ee0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -780,6 +780,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo "Please don't use queueStream when checkpointing is enabled.")) } + test("Creating an InputDStream but not using it should not crash") { + ssc = new StreamingContext(master, appName, batchDuration) + val input1 = addInputStream(ssc) + val input2 = addInputStream(ssc) + val output = new TestOutputStream(input2) + output.register() + val batchCount = new BatchCounter(ssc) + ssc.start() + // Just wait for completing 2 batches to make sure it triggers + // `DStream.getMaxInputStreamRememberDuration` + batchCount.waitUntilBatchesCompleted(2, 10000) + // Throw the exception if crash + ssc.awaitTerminationOrTimeout(1) + ssc.stop() + } + def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => 1 to i) val inputStream = new TestInputStream(s, input, 1) From 574141a29835ce78d68c97bb54336cf4fd3c39d3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 6 Nov 2015 10:52:04 -0800 Subject: [PATCH 14/88] [SPARK-9162] [SQL] Implement code generation for ScalaUDF JIRA: https://issues.apache.org/jira/browse/SPARK-9162 Currently ScalaUDF extends CodegenFallback and doesn't provide code generation implementation. This path implements code generation for ScalaUDF. Author: Liang-Chi Hsieh Closes #9270 from viirya/scalaudf-codegen. --- .../sql/catalyst/expressions/ScalaUDF.scala | 85 ++++++++++++++++++- .../scala/org/apache/spark/sql/UDFSuite.scala | 41 +++++++++ 2 files changed, 124 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 11c7950c0613b..3388cc20a9803 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.DataType /** @@ -31,7 +31,7 @@ case class ScalaUDF( dataType: DataType, children: Seq[Expression], inputTypes: Seq[DataType] = Nil) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { override def nullable: Boolean = true @@ -60,6 +60,10 @@ case class ScalaUDF( */ + // Accessors used in genCode + def userDefinedFunc(): AnyRef = function + def getChildren(): Seq[Expression] = children + private[this] val f = children.size match { case 0 => val func = function.asInstanceOf[() => Any] @@ -960,6 +964,83 @@ case class ScalaUDF( } // scalastyle:on + + // Generate codes used to convert the arguments to Scala type for user-defined funtions + private[this] def genCodeForConverter(ctx: CodeGenContext, index: Int): String = { + val converterClassName = classOf[Any => Any].getName + val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" + val expressionClassName = classOf[Expression].getName + val scalaUDFClassName = classOf[ScalaUDF].getName + + val converterTerm = ctx.freshName("converter") + val expressionIdx = ctx.references.size - 1 + ctx.addMutableState(converterClassName, converterTerm, + s"this.$converterTerm = ($converterClassName)$typeConvertersClassName" + + s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + + s"expressions[$expressionIdx]).getChildren().apply($index))).dataType());") + converterTerm + } + + override def genCode( + ctx: CodeGenContext, + ev: GeneratedExpressionCode): String = { + + ctx.references += this + + val scalaUDFClassName = classOf[ScalaUDF].getName + val converterClassName = classOf[Any => Any].getName + val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" + val expressionClassName = classOf[Expression].getName + + // Generate codes used to convert the returned value of user-defined functions to Catalyst type + val catalystConverterTerm = ctx.freshName("catalystConverter") + val catalystConverterTermIdx = ctx.references.size - 1 + ctx.addMutableState(converterClassName, catalystConverterTerm, + s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + + s".createToCatalystConverter((($scalaUDFClassName)expressions" + + s"[$catalystConverterTermIdx]).dataType());") + + val resultTerm = ctx.freshName("result") + + // This must be called before children expressions' codegen + // because ctx.references is used in genCodeForConverter + val converterTerms = (0 until children.size).map(genCodeForConverter(ctx, _)) + + // Initialize user-defined function + val funcClassName = s"scala.Function${children.size}" + + val funcTerm = ctx.freshName("udf") + val funcExpressionIdx = ctx.references.size - 1 + ctx.addMutableState(funcClassName, funcTerm, + s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)expressions" + + s"[$funcExpressionIdx]).userDefinedFunc());") + + // codegen for children expressions + val evals = children.map(_.gen(ctx)) + + // Generate the codes for expressions and calling user-defined function + // We need to get the boxedType of dataType's javaType here. Because for the dataType + // such as IntegerType, its javaType is `int` and the returned type of user-defined + // function is Object. Trying to convert an Object to `int` will cause casting exception. + val evalCode = evals.map(_.code).mkString + val funcArguments = converterTerms.zip(evals).map { + case (converter, eval) => s"$converter.apply(${eval.value})" + }.mkString(",") + val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " + + s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" + + s".apply($funcTerm.apply($funcArguments));" + + evalCode + s""" + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + Boolean ${ev.isNull}; + + $callFunc + + ${ev.value} = $resultTerm; + ${ev.isNull} = $resultTerm == null; + """ + } + private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) override def eval(input: InternalRow): Any = converter(f(input)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index e0435a0dba6ad..9837fa6bdb357 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -191,4 +191,45 @@ class UDFSuite extends QueryTest with SharedSQLContext { // pass a decimal to intExpected. assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } + + test("udf in different types") { + sqlContext.udf.register("testDataFunc", (n: Int, s: String) => { (n, s) }) + sqlContext.udf.register("decimalDataFunc", + (a: java.math.BigDecimal, b: java.math.BigDecimal) => { (a, b) }) + sqlContext.udf.register("binaryDataFunc", (a: Array[Byte], b: Int) => { (a, b) }) + sqlContext.udf.register("arrayDataFunc", + (data: Seq[Int], nestedData: Seq[Seq[Int]]) => { (data, nestedData) }) + sqlContext.udf.register("mapDataFunc", + (data: scala.collection.Map[Int, String]) => { data }) + sqlContext.udf.register("complexDataFunc", + (m: Map[String, Int], a: Seq[Int], b: Boolean) => { (m, a, b) } ) + + checkAnswer( + sql("SELECT tmp.t.* FROM (SELECT testDataFunc(key, value) AS t from testData) tmp").toDF(), + testData) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT decimalDataFunc(a, b) AS t FROM decimalData) tmp + """.stripMargin).toDF(), decimalData) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT binaryDataFunc(a, b) AS t FROM binaryData) tmp + """.stripMargin).toDF(), binaryData) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT arrayDataFunc(data, nestedData) AS t FROM arrayData) tmp + """.stripMargin).toDF(), arrayData.toDF()) + checkAnswer( + sql(""" + | SELECT mapDataFunc(data) AS t FROM mapData + """.stripMargin).toDF(), mapData.toDF()) + checkAnswer( + sql(""" + | SELECT tmp.t.* FROM + | (SELECT complexDataFunc(m, a, b) AS t FROM complexData) tmp + """.stripMargin).toDF(), complexData.select("m", "a", "b")) + } } From c048929c6a9f7ce57f384037cd6c0bf5751c447a Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 6 Nov 2015 11:11:36 -0800 Subject: [PATCH 15/88] [SPARK-10978][SQL][FOLLOW-UP] More comprehensive tests for PR #9399 This PR adds test cases that test various column pruning and filter push-down cases. Author: Cheng Lian Closes #9468 from liancheng/spark-10978.follow-up. --- .../spark/sql/sources/FilteredScanSuite.scala | 21 +- .../SimpleTextHadoopFsRelationSuite.scala | 335 ++++++++++++++++-- .../sql/sources/SimpleTextRelation.scala | 11 + 3 files changed, 321 insertions(+), 46 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala index 7541e723029bf..2cad964e55b2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.sql.sources -import org.apache.spark.sql.execution.datasources.LogicalRelation - import scala.language.existentials import org.apache.spark.rdd.RDD -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.PredicateHelper +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ - +import org.apache.spark.unsafe.types.UTF8String class FilteredScanSource extends RelationProvider { override def createRelation( @@ -130,7 +129,7 @@ object ColumnsRequired { var set: Set[String] = Set.empty } -class FilteredScanSuite extends DataSourceTest with SharedSQLContext { +class FilteredScanSuite extends DataSourceTest with SharedSQLContext with PredicateHelper { protected override lazy val sql = caseInsensitiveContext.sql _ override def beforeAll(): Unit = { @@ -144,9 +143,6 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { | to '10' |) """.stripMargin) - - // UDF for testing filter push-down - caseInsensitiveContext.udf.register("udf_gt3", (_: Int) > 3) } sqlTest( @@ -276,14 +272,15 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext { testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1, Set("c")) testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1, Set("c")) - // Columns only referenced by UDF filter must be required, as UDF filters can't be pushed down. - testPushDown("SELECT c FROM oneToTenFiltered WHERE udf_gt3(A)", 10, Set("a", "c")) + // Filters referencing multiple columns are not convertible, all referenced columns must be + // required. + testPushDown("SELECT c FROM oneToTenFiltered WHERE A + b > 9", 10, Set("a", "b", "c")) - // A query with an unconvertible filter, an unhandled filter, and a handled filter. + // A query with an inconvertible filter, an unhandled filter, and a handled filter. testPushDown( """SELECT a | FROM oneToTenFiltered - | WHERE udf_gt3(b) + | WHERE a + b > 9 | AND b < 16 | AND c IN ('bbbbbBBBBB', 'cccccCCCCC', 'dddddDDDDD', 'foo') """.stripMargin.split("\n").map(_.trim).mkString(" "), 3, Set("a", "b")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index d945408341fc9..9251a69f31a47 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -17,15 +17,21 @@ package org.apache.spark.sql.sources +import java.io.File + import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.execution.PhysicalRDD +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, PredicateHelper} +import org.apache.spark.sql.execution.{LogicalRDD, PhysicalRDD} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.{Column, DataFrame, Row, execution} +import org.apache.spark.util.Utils -class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { +class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with PredicateHelper { import testImplicits._ override val dataSourceName: String = classOf[SimpleTextSource].getCanonicalName @@ -70,43 +76,304 @@ class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest { } } - private val writer = testDF.write.option("dataSchema", dataSchema.json).format(dataSourceName) - private val reader = sqlContext.read.option("dataSchema", dataSchema.json).format(dataSourceName) - - test("unhandledFilters") { - withTempPath { dir => - - val path = dir.getCanonicalPath - writer.save(s"$path/p=0") - writer.save(s"$path/p=1") - - val isOdd = udf((_: Int) % 2 == 1) - val df = reader.load(path) - .filter( - // This filter is inconvertible - isOdd('a) && - // This filter is convertible but unhandled - 'a > 1 && - // This filter is convertible and handled - 'b > "val_1" && - // This filter references a partiiton column, won't be pushed down - 'p === 1 - ).select('a, 'p) - val rawScan = df.queryExecution.executedPlan collect { + private var tempPath: File = _ + + private var partitionedDF: DataFrame = _ + + private val partitionedDataSchema: StructType = StructType('a.int :: 'b.int :: 'c.string :: Nil) + + protected override def beforeAll(): Unit = { + this.tempPath = Utils.createTempDir() + + val df = sqlContext.range(10).select( + 'id cast IntegerType as 'a, + ('id cast IntegerType) * 2 as 'b, + concat(lit("val_"), 'id) as 'c + ) + + partitionedWriter(df).save(s"${tempPath.getCanonicalPath}/p=0") + partitionedWriter(df).save(s"${tempPath.getCanonicalPath}/p=1") + + partitionedDF = partitionedReader.load(tempPath.getCanonicalPath) + } + + override protected def afterAll(): Unit = { + Utils.deleteRecursively(tempPath) + } + + private def partitionedWriter(df: DataFrame) = + df.write.option("dataSchema", partitionedDataSchema.json).format(dataSourceName) + + private def partitionedReader = + sqlContext.read.option("dataSchema", partitionedDataSchema.json).format(dataSourceName) + + /** + * Constructs test cases that test column pruning and filter push-down. + * + * For filter push-down, the following filters are not pushed-down. + * + * 1. Partitioning filters don't participate filter push-down, they are handled separately in + * `DataSourceStrategy` + * + * 2. Catalyst filter `Expression`s that cannot be converted to data source `Filter`s are not + * pushed down (e.g. UDF and filters referencing multiple columns). + * + * 3. Catalyst filter `Expression`s that can be converted to data source `Filter`s but cannot be + * handled by the underlying data source are not pushed down (e.g. returned from + * `BaseRelation.unhandledFilters()`). + * + * Note that for [[SimpleTextRelation]], all data source [[Filter]]s other than [[GreaterThan]] + * are unhandled. We made this assumption in [[SimpleTextRelation.unhandledFilters()]] only + * for testing purposes. + * + * @param projections Projection list of the query + * @param filter Filter condition of the query + * @param requiredColumns Expected names of required columns + * @param pushedFilters Expected data source [[Filter]]s that are pushed down + * @param inconvertibleFilters Expected Catalyst filter [[Expression]]s that cannot be converted + * to data source [[Filter]]s + * @param unhandledFilters Expected Catalyst flter [[Expression]]s that can be converted to data + * source [[Filter]]s but cannot be handled by the data source relation + * @param partitioningFilters Expected Catalyst filter [[Expression]]s that reference partition + * columns + * @param expectedRawScanAnswer Expected query result of the raw table scan returned by the data + * source relation + * @param expectedAnswer Expected query result of the full query + */ + def testPruningAndFiltering( + projections: Seq[Column], + filter: Column, + requiredColumns: Seq[String], + pushedFilters: Seq[Filter], + inconvertibleFilters: Seq[Column], + unhandledFilters: Seq[Column], + partitioningFilters: Seq[Column])( + expectedRawScanAnswer: => Seq[Row])( + expectedAnswer: => Seq[Row]): Unit = { + test(s"pruning and filtering: df.select(${projections.mkString(", ")}).where($filter)") { + val df = partitionedDF.where(filter).select(projections: _*) + val queryExecution = df.queryExecution + val executedPlan = queryExecution.executedPlan + + val rawScan = executedPlan.collect { case p: PhysicalRDD => p } match { - case Seq(p) => p + case Seq(scan) => scan + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") } - val outputSchema = new StructType().add("a", IntegerType).add("p", IntegerType) + markup("Checking raw scan answer") + checkAnswer( + DataFrame(sqlContext, LogicalRDD(rawScan.output, rawScan.rdd)(sqlContext)), + expectedRawScanAnswer) - assertResult(Set((2, 1), (3, 1))) { - rawScan.execute().collect() - .map { CatalystTypeConverters.convertToScala(_, outputSchema) } - .map { case Row(a, p) => (a, p) }.toSet + markup("Checking full query answer") + checkAnswer(df, expectedAnswer) + + markup("Checking required columns") + assert(requiredColumns === SimpleTextRelation.requiredColumns) + + val nonPushedFilters = { + val boundFilters = executedPlan.collect { + case f: execution.Filter => f + } match { + case Nil => Nil + case Seq(f) => splitConjunctivePredicates(f.condition) + case _ => fail(s"More than one PhysicalRDD found\n$queryExecution") + } + + // Unbound these bound filters so that we can easily compare them with expected results. + boundFilters.map { + _.transform { case a: AttributeReference => UnresolvedAttribute(a.name) } + }.toSet } - checkAnswer(df, Row(3, 1)) + markup("Checking pushed filters") + assert(SimpleTextRelation.pushedFilters === pushedFilters.toSet) + + val expectedInconvertibleFilters = inconvertibleFilters.map(_.expr).toSet + val expectedUnhandledFilters = unhandledFilters.map(_.expr).toSet + val expectedPartitioningFilters = partitioningFilters.map(_.expr).toSet + + markup("Checking unhandled and inconvertible filters") + assert(expectedInconvertibleFilters ++ expectedUnhandledFilters === nonPushedFilters) + + markup("Checking partitioning filters") + val actualPartitioningFilters = splitConjunctivePredicates(filter.expr).filter { + _.references.contains(UnresolvedAttribute("p")) + }.toSet + + // Partitioning filters are handled separately and don't participate filter push-down. So they + // shouldn't be part of non-pushed filters. + assert(expectedPartitioningFilters.intersect(nonPushedFilters).isEmpty) + assert(expectedPartitioningFilters === actualPartitioningFilters) } } + + testPruningAndFiltering( + projections = Seq('*), + filter = 'p > 0, + requiredColumns = Seq("a", "b", "c"), + pushedFilters = Nil, + inconvertibleFilters = Nil, + unhandledFilters = Nil, + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row(0, 0, "val_0", 1), + Row(1, 2, "val_1", 1), + Row(2, 4, "val_2", 1), + Row(3, 6, "val_3", 1), + Row(4, 8, "val_4", 1), + Row(5, 10, "val_5", 1), + Row(6, 12, "val_6", 1), + Row(7, 14, "val_7", 1), + Row(8, 16, "val_8", 1), + Row(9, 18, "val_9", 1)) + } { + Seq( + Row(0, 0, "val_0", 1), + Row(1, 2, "val_1", 1), + Row(2, 4, "val_2", 1), + Row(3, 6, "val_3", 1), + Row(4, 8, "val_4", 1), + Row(5, 10, "val_5", 1), + Row(6, 12, "val_6", 1), + Row(7, 14, "val_7", 1), + Row(8, 16, "val_8", 1), + Row(9, 18, "val_9", 1)) + } + + testPruningAndFiltering( + projections = Seq('c, 'p), + filter = 'a < 3 && 'p > 0, + requiredColumns = Seq("c", "a"), + pushedFilters = Nil, + inconvertibleFilters = Nil, + unhandledFilters = Seq('a < 3), + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row("val_0", 1, 0), + Row("val_1", 1, 1), + Row("val_2", 1, 2), + Row("val_3", 1, 3), + Row("val_4", 1, 4), + Row("val_5", 1, 5), + Row("val_6", 1, 6), + Row("val_7", 1, 7), + Row("val_8", 1, 8), + Row("val_9", 1, 9)) + } { + Seq( + Row("val_0", 1), + Row("val_1", 1), + Row("val_2", 1)) + } + + testPruningAndFiltering( + projections = Seq('*), + filter = 'a > 8, + requiredColumns = Seq("a", "b", "c"), + pushedFilters = Seq(GreaterThan("a", 8)), + inconvertibleFilters = Nil, + unhandledFilters = Nil, + partitioningFilters = Nil + ) { + Seq( + Row(9, 18, "val_9", 0), + Row(9, 18, "val_9", 1)) + } { + Seq( + Row(9, 18, "val_9", 0), + Row(9, 18, "val_9", 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'a > 8, + requiredColumns = Seq("b"), + pushedFilters = Seq(GreaterThan("a", 8)), + inconvertibleFilters = Nil, + unhandledFilters = Nil, + partitioningFilters = Nil + ) { + Seq( + Row(18, 0), + Row(18, 1)) + } { + Seq( + Row(18, 0), + Row(18, 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'a > 8 && 'p > 0, + requiredColumns = Seq("b"), + pushedFilters = Seq(GreaterThan("a", 8)), + inconvertibleFilters = Nil, + unhandledFilters = Nil, + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row(18, 1)) + } { + Seq( + Row(18, 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'c > "val_7" && 'b < 18 && 'p > 0, + requiredColumns = Seq("b"), + pushedFilters = Seq(GreaterThan("c", "val_7")), + inconvertibleFilters = Nil, + unhandledFilters = Seq('b < 18), + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row(16, 1), + Row(18, 1)) + } { + Seq( + Row(16, 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'a % 2 === 0 && 'c > "val_7" && 'b < 18 && 'p > 0, + requiredColumns = Seq("b", "a"), + pushedFilters = Seq(GreaterThan("c", "val_7")), + inconvertibleFilters = Seq('a % 2 === 0), + unhandledFilters = Seq('b < 18), + partitioningFilters = Seq('p > 0) + ) { + Seq( + Row(16, 1, 8), + Row(18, 1, 9)) + } { + Seq( + Row(16, 1)) + } + + testPruningAndFiltering( + projections = Seq('b, 'p), + filter = 'a > 7 && 'a < 9, + requiredColumns = Seq("b", "a"), + pushedFilters = Seq(GreaterThan("a", 7)), + inconvertibleFilters = Nil, + unhandledFilters = Seq('a < 9), + partitioningFilters = Nil + ) { + Seq( + Row(16, 0, 8), + Row(16, 1, 8), + Row(18, 0, 9), + Row(18, 1, 9)) + } { + Seq( + Row(16, 0), + Row(16, 1)) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index da09e1b00ae48..bdc48a383bbbf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -128,6 +128,9 @@ class SimpleTextRelation( filters: Array[Filter], inputFiles: Array[FileStatus]): RDD[Row] = { + SimpleTextRelation.requiredColumns = requiredColumns + SimpleTextRelation.pushedFilters = filters.toSet + val fields = this.dataSchema.map(_.dataType) val inputAttributes = this.dataSchema.toAttributes val outputAttributes = requiredColumns.flatMap(name => inputAttributes.find(_.name == name)) @@ -191,6 +194,14 @@ class SimpleTextRelation( } } +object SimpleTextRelation { + // Used to test column pruning + var requiredColumns: Seq[String] = Nil + + // Used to test filter push-down + var pushedFilters: Set[Filter] = Set.empty +} + /** * A simple example [[HadoopFsRelationProvider]]. */ From 8211aab0793cf64202b99be4f31bb8a9ae77050d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 6 Nov 2015 11:13:51 -0800 Subject: [PATCH 16/88] [SPARK-9858][SQL] Add an ExchangeCoordinator to estimate the number of post-shuffle partitions for aggregates and joins (follow-up) https://issues.apache.org/jira/browse/SPARK-9858 This PR is the follow-up work of https://github.com/apache/spark/pull/9276. It addresses JoshRosen's comments. Author: Yin Huai Closes #9453 from yhuai/numReducer-followUp. --- .../plans/physical/partitioning.scala | 8 - .../apache/spark/sql/execution/Exchange.scala | 40 +++-- .../sql/execution/ExchangeCoordinator.scala | 31 ++-- .../apache/spark/sql/CachedTableSuite.scala | 150 ++++++++++++++---- 4 files changed, 167 insertions(+), 62 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 9312c8123e92e..86b9417477ba3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -165,11 +165,6 @@ sealed trait Partitioning { * produced by `A` could have also been produced by `B`. */ def guarantees(other: Partitioning): Boolean = this == other - - def withNumPartitions(newNumPartitions: Int): Partitioning = { - throw new IllegalStateException( - s"It is not allowed to call withNumPartitions method of a ${this.getClass.getSimpleName}") - } } object Partitioning { @@ -254,9 +249,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def withNumPartitions(newNumPartitions: Int): HashPartitioning = { - HashPartitioning(expressions, newNumPartitions) - } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 0f72ec6cc107a..a4ce328c1a9eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -242,7 +242,7 @@ case class Exchange( // update the number of post-shuffle partitions. specifiedPartitionStartIndices.foreach { indices => assert(newPartitioning.isInstanceOf[HashPartitioning]) - newPartitioning = newPartitioning.withNumPartitions(indices.length) + newPartitioning = UnknownPartitioning(indices.length) } new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices) } @@ -262,7 +262,7 @@ case class Exchange( object Exchange { def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = { - Exchange(newPartitioning, child, None: Option[ExchangeCoordinator]) + Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator]) } } @@ -315,7 +315,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ child.outputPartitioning match { case hash: HashPartitioning => true case collection: PartitioningCollection => - collection.partitionings.exists(_.isInstanceOf[HashPartitioning]) + collection.partitionings.forall(_.isInstanceOf[HashPartitioning]) case _ => false } } @@ -416,28 +416,48 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ // First check if the existing partitions of the children all match. This means they are // partitioned by the same partitioning into the same number of partitions. In that case, // don't try to make them match `defaultPartitions`, just use the existing partitioning. - // TODO: this should be a cost based decision. For example, a big relation should probably - // maintain its existing number of partitions and smaller partitions should be shuffled. - // defaultPartitions is arbitrary. - val numPartitions = children.head.outputPartitioning.numPartitions + val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max val useExistingPartitioning = children.zip(requiredChildDistributions).forall { case (child, distribution) => { child.outputPartitioning.guarantees( - createPartitioning(distribution, numPartitions)) + createPartitioning(distribution, maxChildrenNumPartitions)) } } children = if (useExistingPartitioning) { + // We do not need to shuffle any child's output. children } else { + // We need to shuffle at least one child's output. + // Now, we will determine the number of partitions that will be used by created + // partitioning schemes. + val numPartitions = { + // Let's see if we need to shuffle all child's outputs when we use + // maxChildrenNumPartitions. + val shufflesAllChildren = children.zip(requiredChildDistributions).forall { + case (child, distribution) => { + !child.outputPartitioning.guarantees( + createPartitioning(distribution, maxChildrenNumPartitions)) + } + } + // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the + // number of partitions. Otherwise, we use maxChildrenNumPartitions. + if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions + } + children.zip(requiredChildDistributions).map { case (child, distribution) => { val targetPartitioning = - createPartitioning(distribution, defaultNumPreShufflePartitions) + createPartitioning(distribution, numPartitions) if (child.outputPartitioning.guarantees(targetPartitioning)) { child } else { - Exchange(targetPartitioning, child) + child match { + // If child is an exchange, we replace it with + // a new one having targetPartitioning. + case Exchange(_, c, _) => Exchange(targetPartitioning, c) + case _ => Exchange(targetPartitioning, child) + } } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala index 8dbd69e1f44b8..827fdd278460a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.util.{Map => JMap, HashMap => JHashMap} +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ArrayBuffer @@ -97,6 +98,7 @@ private[sql] class ExchangeCoordinator( * Registers an [[Exchange]] operator to this coordinator. This method is only allowed to be * called in the `doPrepare` method of an [[Exchange]] operator. */ + @GuardedBy("this") def registerExchange(exchange: Exchange): Unit = synchronized { exchanges += exchange } @@ -109,7 +111,7 @@ private[sql] class ExchangeCoordinator( */ private[sql] def estimatePartitionStartIndices( mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = { - // If we have mapOutputStatistics.length <= numExchange, it is because we do not submit + // If we have mapOutputStatistics.length < numExchange, it is because we do not submit // a stage when the number of partitions of this dependency is 0. assert(mapOutputStatistics.length <= numExchanges) @@ -121,6 +123,8 @@ private[sql] class ExchangeCoordinator( val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum // The max at here is to make sure that when we have an empty table, we // only have a single post-shuffle partition. + // There is no particular reason that we pick 16. We just need a number to + // prevent maxPostShuffleInputSize from being set to 0. val maxPostShuffleInputSize = math.max(math.ceil(totalPostShuffleInputSize / numPartitions.toDouble).toLong, 16) math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize) @@ -135,6 +139,12 @@ private[sql] class ExchangeCoordinator( // Make sure we do get the same number of pre-shuffle partitions for those stages. val distinctNumPreShufflePartitions = mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct + // The reason that we are expecting a single value of the number of pre-shuffle partitions + // is that when we add Exchanges, we set the number of pre-shuffle partitions + // (i.e. map output partitions) using a static setting, which is the value of + // spark.sql.shuffle.partitions. Even if two input RDDs are having different + // number of partitions, they will have the same number of pre-shuffle partitions + // (i.e. map output partitions). assert( distinctNumPreShufflePartitions.length == 1, "There should be only one distinct value of the number pre-shuffle partitions " + @@ -177,6 +187,7 @@ private[sql] class ExchangeCoordinator( partitionStartIndices.toArray } + @GuardedBy("this") private def doEstimationIfNecessary(): Unit = synchronized { // It is unlikely that this method will be called from multiple threads // (when multiple threads trigger the execution of THIS physical) @@ -209,11 +220,11 @@ private[sql] class ExchangeCoordinator( // Wait for the finishes of those submitted map stages. val mapOutputStatistics = new Array[MapOutputStatistics](submittedStageFutures.length) - i = 0 - while (i < submittedStageFutures.length) { + var j = 0 + while (j < submittedStageFutures.length) { // This call is a blocking call. If the stage has not finished, we will wait at here. - mapOutputStatistics(i) = submittedStageFutures(i).get() - i += 1 + mapOutputStatistics(j) = submittedStageFutures(j).get() + j += 1 } // Now, we estimate partitionStartIndices. partitionStartIndices.length will be the @@ -225,14 +236,14 @@ private[sql] class ExchangeCoordinator( Some(estimatePartitionStartIndices(mapOutputStatistics)) } - i = 0 - while (i < numExchanges) { - val exchange = exchanges(i) + var k = 0 + while (k < numExchanges) { + val exchange = exchanges(k) val rdd = - exchange.preparePostShuffleRDD(shuffleDependencies(i), partitionStartIndices) + exchange.preparePostShuffleRDD(shuffleDependencies(k), partitionStartIndices) newPostShuffleRDDs.put(exchange, rdd) - i += 1 + k += 1 } // Finally, we set postShuffleRDDs and estimated. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index dbcb011f603f7..bce94dafad755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -29,12 +29,12 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.columnar._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext} import org.apache.spark.storage.{StorageLevel, RDDBlockId} private case class BigData(s: String) -class CachedTableSuite extends QueryTest with SharedSQLContext { +class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext { import testImplicits._ def rddIdOf(tableName: String): Int = { @@ -375,53 +375,135 @@ class CachedTableSuite extends QueryTest with SharedSQLContext { sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"), sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect()) sqlContext.uncacheTable("orderedTable") + sqlContext.dropTempTable("orderedTable") // Set up two tables distributed in the same way. Try this with the data distributed into // different number of partitions. for (numPartitions <- 1 until 10 by 4) { - testData.repartition(numPartitions, $"key").registerTempTable("t1") - testData2.repartition(numPartitions, $"a").registerTempTable("t2") + withTempTable("t1", "t2") { + testData.repartition(numPartitions, $"key").registerTempTable("t1") + testData2.repartition(numPartitions, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + // Joining them should result in no exchanges. + verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) + checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), + sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a")) + + // Grouping on the partition key should result in no exchanges + verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0) + checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), + sql("SELECT count(*) FROM testData GROUP BY key")) + + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + } + + // Distribute the tables into non-matching number of partitions. Need to shuffle one side. + withTempTable("t1", "t2") { + testData.repartition(6, $"key").registerTempTable("t1") + testData2.repartition(3, $"a").registerTempTable("t2") sqlContext.cacheTable("t1") sqlContext.cacheTable("t2") - // Joining them should result in no exchanges. - verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0) - checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), - sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a")) + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } - // Grouping on the partition key should result in no exchanges - verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0) - checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"), - sql("SELECT count(*) FROM testData GROUP BY key")) + // One side of join is not partitioned in the desired way. Need to shuffle one side. + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(6, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) sqlContext.uncacheTable("t1") sqlContext.uncacheTable("t2") - sqlContext.dropTempTable("t1") - sqlContext.dropTempTable("t2") } - // Distribute the tables into non-matching number of partitions. Need to shuffle. - testData.repartition(6, $"key").registerTempTable("t1") - testData2.repartition(3, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(12, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") - verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") - sqlContext.dropTempTable("t1") - sqlContext.dropTempTable("t2") + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 12) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } - // One side of join is not partitioned in the desired way. Need to shuffle. - testData.repartition(6, $"value").registerTempTable("t1") - testData2.repartition(6, $"a").registerTempTable("t2") - sqlContext.cacheTable("t1") - sqlContext.cacheTable("t2") + // One side of join is not partitioned in the desired way. Since the number of partitions of + // the side that has already partitioned is smaller than the side that is not partitioned, + // we shuffle both side. + withTempTable("t1", "t2") { + testData.repartition(6, $"value").registerTempTable("t1") + testData2.repartition(3, $"a").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") - verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2) - sqlContext.uncacheTable("t1") - sqlContext.uncacheTable("t2") - sqlContext.dropTempTable("t1") - sqlContext.dropTempTable("t2") + val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a") + verifyNumExchanges(query, 2) + checkAnswer( + query, + testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } + + // repartition's column ordering is different from group by column ordering. + // But they use the same set of columns. + withTempTable("t1") { + testData.repartition(6, $"value", $"key").registerTempTable("t1") + sqlContext.cacheTable("t1") + + val query = sql("SELECT value, key from t1 group by key, value") + verifyNumExchanges(query, 0) + checkAnswer( + query, + testData.distinct().select($"value", $"key")) + sqlContext.uncacheTable("t1") + } + + // repartition's column ordering is different from join condition's column ordering. + // We will still shuffle because hashcodes of a row depend on the column ordering. + // If we do not shuffle, we may actually partition two tables in totally two different way. + // See PartitioningSuite for more details. + withTempTable("t1", "t2") { + val df1 = testData + df1.repartition(6, $"value", $"key").registerTempTable("t1") + val df2 = testData2.select($"a", $"b".cast("string")) + df2.repartition(6, $"a", $"b").registerTempTable("t2") + sqlContext.cacheTable("t1") + sqlContext.cacheTable("t2") + + val query = + sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a and t1.value = t2.b") + verifyNumExchanges(query, 1) + assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6) + checkAnswer( + query, + df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", $"value", $"a", $"b")) + sqlContext.uncacheTable("t1") + sqlContext.uncacheTable("t2") + } } } From 62bb290773c9f9fa53cbe6d4eedc6e153761a763 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 6 Nov 2015 20:05:18 +0000 Subject: [PATCH 17/88] Typo fixes + code readability improvements Author: Jacek Laskowski Closes #9501 from jaceklaskowski/typos-with-style. --- .../scala/org/apache/spark/rdd/HadoopRDD.scala | 14 ++++++-------- .../org/apache/spark/scheduler/DAGScheduler.scala | 12 +++++++++--- .../apache/spark/scheduler/ShuffleMapTask.scala | 10 +++++----- .../scala/org/apache/spark/scheduler/TaskSet.scala | 2 +- 4 files changed, 21 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index d841f05ec52cf..0453614f6a1d3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -88,8 +88,8 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, s: InputSplit) * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed - * variabe references an instance of JobConf, then that JobConf will be used for the Hadoop job. - * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration. + * variable references an instance of JobConf, then that JobConf will be used for the Hadoop job. + * Otherwise, a new JobConf will be created on each slave using the enclosed Configuration. * @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD * creates. * @param inputFormatClass Storage format of the data to be read. @@ -123,7 +123,7 @@ class HadoopRDD[K, V]( sc, sc.broadcast(new SerializableConfiguration(conf)) .asInstanceOf[Broadcast[SerializableConfiguration]], - None /* initLocalJobConfFuncOpt */, + initLocalJobConfFuncOpt = None, inputFormatClass, keyClass, valueClass, @@ -184,8 +184,9 @@ class HadoopRDD[K, V]( protected def getInputFormat(conf: JobConf): InputFormat[K, V] = { val newInputFormat = ReflectionUtils.newInstance(inputFormatClass.asInstanceOf[Class[_]], conf) .asInstanceOf[InputFormat[K, V]] - if (newInputFormat.isInstanceOf[Configurable]) { - newInputFormat.asInstanceOf[Configurable].setConf(conf) + newInputFormat match { + case c: Configurable => c.setConf(conf) + case _ => } newInputFormat } @@ -195,9 +196,6 @@ class HadoopRDD[K, V]( // add the credentials here as this can be called before SparkContext initialized SparkHadoopUtil.get.addCredentials(jobConf) val inputFormat = getInputFormat(jobConf) - if (inputFormat.isInstanceOf[Configurable]) { - inputFormat.asInstanceOf[Configurable].setConf(jobConf) - } val inputSplits = inputFormat.getSplits(jobConf, minPartitions) val array = new Array[Partition](inputSplits.size) for (i <- 0 until inputSplits.size) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index a1f0fd05f661a..4a9518fff4e7b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -541,8 +541,7 @@ class DAGScheduler( } /** - * Submit an action job to the scheduler and get a JobWaiter object back. The JobWaiter object - * can be used to block until the the job finishes executing or can be used to cancel the job. + * Submit an action job to the scheduler. * * @param rdd target RDD to run tasks on * @param func a function to run on each partition of the RDD @@ -551,6 +550,11 @@ class DAGScheduler( * @param callSite where in the user program this job was called * @param resultHandler callback to pass each result to * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + * + * @return a JobWaiter object that can be used to block until the job finishes executing + * or can be used to cancel the job. + * + * @throws IllegalArgumentException when partitions ids are illegal */ def submitJob[T, U]( rdd: RDD[T], @@ -584,7 +588,7 @@ class DAGScheduler( /** * Run an action job on the given RDD and pass all the results to the resultHandler function as - * they arrive. Throws an exception if the job fials, or returns normally if successful. + * they arrive. * * @param rdd target RDD to run tasks on * @param func a function to run on each partition of the RDD @@ -593,6 +597,8 @@ class DAGScheduler( * @param callSite where in the user program this job was called * @param resultHandler callback to pass each result to * @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name + * + * @throws Exception when the job fails */ def runJob[T, U]( rdd: RDD[T], diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index f478f9982afef..ea97ef0e746d8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -27,11 +27,11 @@ import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleWriter /** -* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner -* specified in the ShuffleDependency). -* -* See [[org.apache.spark.scheduler.Task]] for more information. -* + * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner + * specified in the ShuffleDependency). + * + * See [[org.apache.spark.scheduler.Task]] for more information. + * * @param stageId id of the stage this task belongs to * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized, * the type should be (RDD[_], ShuffleDependency[_, _, _]). diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index be8526ba9b94f..517c8991aed78 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -29,7 +29,7 @@ private[spark] class TaskSet( val stageAttemptId: Int, val priority: Int, val properties: Properties) { - val id: String = stageId + "." + stageAttemptId + val id: String = stageId + "." + stageAttemptId override def toString: String = "TaskSet " + id } From 49f1a820372d1cba41f3f00d07eb5728f2ed6705 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Fri, 6 Nov 2015 20:06:24 +0000 Subject: [PATCH 18/88] [SPARK-10116][CORE] XORShiftRandom.hashSeed is random in high bits https://issues.apache.org/jira/browse/SPARK-10116 This is really trivial, just happened to notice it -- if `XORShiftRandom.hashSeed` is really supposed to have random bits throughout (as the comment implies), it needs to do something for the conversion to `long`. mengxr mkolod Author: Imran Rashid Closes #8314 from squito/SPARK-10116. --- R/pkg/inst/tests/test_sparkSQL.R | 8 +-- .../spark/util/random/XORShiftRandom.scala | 6 ++- .../java/org/apache/spark/JavaAPISuite.java | 20 ++++--- .../spark/rdd/PairRDDFunctionsSuite.scala | 52 +++++++++++++------ .../util/random/XORShiftRandomSuite.scala | 15 ++++++ .../MultilayerPerceptronClassifierSuite.scala | 5 +- .../spark/ml/feature/Word2VecSuite.scala | 16 ++++-- .../clustering/StreamingKMeansSuite.scala | 13 +++-- python/pyspark/ml/feature.py | 20 +++---- python/pyspark/ml/recommendation.py | 6 +-- python/pyspark/mllib/recommendation.py | 4 +- python/pyspark/sql/dataframe.py | 6 +-- .../catalyst/expressions/RandomSuite.scala | 8 +-- .../apache/spark/sql/JavaDataFrameSuite.java | 6 ++- .../apache/spark/sql/DataFrameStatSuite.scala | 4 +- 15 files changed, 128 insertions(+), 61 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 816315b1e4e13..92cff1fba7193 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -875,9 +875,9 @@ test_that("column binary mathfunctions", { expect_equal(collect(select(df, shiftRight(df$b, 1)))[4, 1], 4) expect_equal(collect(select(df, shiftRightUnsigned(df$b, 1)))[4, 1], 4) expect_equal(class(collect(select(df, rand()))[2, 1]), "numeric") - expect_equal(collect(select(df, rand(1)))[1, 1], 0.45, tolerance = 0.01) + expect_equal(collect(select(df, rand(1)))[1, 1], 0.134, tolerance = 0.01) expect_equal(class(collect(select(df, randn()))[2, 1]), "numeric") - expect_equal(collect(select(df, randn(1)))[1, 1], -0.0111, tolerance = 0.01) + expect_equal(collect(select(df, randn(1)))[1, 1], -1.03, tolerance = 0.01) }) test_that("string operators", { @@ -1458,8 +1458,8 @@ test_that("sampleBy() on a DataFrame", { fractions <- list("0" = 0.1, "1" = 0.2) sample <- sampleBy(df, "key", fractions, 0) result <- collect(orderBy(count(groupBy(sample, "key")), "key")) - expect_identical(as.list(result[1, ]), list(key = "0", count = 2)) - expect_identical(as.list(result[2, ]), list(key = "1", count = 10)) + expect_identical(as.list(result[1, ]), list(key = "0", count = 3)) + expect_identical(as.list(result[2, ]), list(key = "1", count = 7)) }) test_that("SQL error message is returned from JVM", { diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index 85fb923cd9bc7..e8cdb6e98bf36 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -60,9 +60,11 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { private[spark] object XORShiftRandom { /** Hash seeds to have 0/1 bits throughout. */ - private def hashSeed(seed: Long): Long = { + private[random] def hashSeed(seed: Long): Long = { val bytes = ByteBuffer.allocate(java.lang.Long.SIZE).putLong(seed).array() - MurmurHash3.bytesHash(bytes) + val lowBits = MurmurHash3.bytesHash(bytes) + val highBits = MurmurHash3.bytesHash(bytes, lowBits) + (highBits.toLong << 32) | (lowBits.toLong & 0xFFFFFFFFL) } /** diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index fd8f7f39b7cc8..4d4e9820500e7 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -146,21 +146,29 @@ public void intersection() { public void sample() { List ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); JavaRDD rdd = sc.parallelize(ints); - JavaRDD sample20 = rdd.sample(true, 0.2, 3); + // the seeds here are "magic" to make this work out nicely + JavaRDD sample20 = rdd.sample(true, 0.2, 8); Assert.assertEquals(2, sample20.count()); - JavaRDD sample20WithoutReplacement = rdd.sample(false, 0.2, 5); + JavaRDD sample20WithoutReplacement = rdd.sample(false, 0.2, 2); Assert.assertEquals(2, sample20WithoutReplacement.count()); } @Test public void randomSplit() { - List ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + List ints = new ArrayList<>(1000); + for (int i = 0; i < 1000; i++) { + ints.add(i); + } JavaRDD rdd = sc.parallelize(ints); JavaRDD[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31); + // the splits aren't perfect -- not enough data for them to be -- just check they're about right Assert.assertEquals(3, splits.length); - Assert.assertEquals(1, splits[0].count()); - Assert.assertEquals(2, splits[1].count()); - Assert.assertEquals(7, splits[2].count()); + long s0 = splits[0].count(); + long s1 = splits[1].count(); + long s2 = splits[2].count(); + Assert.assertTrue(s0 + " not within expected range", s0 > 150 && s0 < 250); + Assert.assertTrue(s1 + " not within expected range", s1 > 250 && s0 < 350); + Assert.assertTrue(s2 + " not within expected range", s2 > 430 && s2 < 570); } @Test diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 1321ec84735b5..7d2cfcca9436a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.rdd +import org.apache.commons.math3.distribution.{PoissonDistribution, BinomialDistribution} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ import org.apache.hadoop.util.Progressable @@ -578,17 +579,36 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" } - def checkSize(exact: Boolean, - withReplacement: Boolean, - expected: Long, - actual: Long, - p: Double): Boolean = { + def assertBinomialSample( + exact: Boolean, + actual: Int, + trials: Int, + p: Double): Unit = { + if (exact) { + assert(actual == math.ceil(p * trials).toInt) + } else { + val dist = new BinomialDistribution(trials, p) + val q = dist.cumulativeProbability(actual) + withClue(s"p = $p: trials = $trials") { + assert(q >= 0.001 && q <= 0.999) + } + } + } + + def assertPoissonSample( + exact: Boolean, + actual: Int, + trials: Int, + p: Double): Unit = { if (exact) { - return expected == actual + assert(actual == math.ceil(p * trials).toInt) + } else { + val dist = new PoissonDistribution(p * trials) + val q = dist.cumulativeProbability(actual) + withClue(s"p = $p: trials = $trials") { + assert(q >= 0.001 && q <= 0.999) + } } - val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) - // Very forgiving margin since we're dealing with very small sample sizes most of the time - math.abs(actual - expected) <= 6 * stdev } def testSampleExact(stratifiedData: RDD[(String, Int)], @@ -613,8 +633,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { samplingRate: Double, seed: Long, n: Long): Unit = { - val expectedSampleSize = stratifiedData.countByKey() - .mapValues(count => math.ceil(count * samplingRate).toInt) + val trials = stratifiedData.countByKey() val fractions = Map("1" -> samplingRate, "0" -> samplingRate) val sample = if (exact) { stratifiedData.sampleByKeyExact(false, fractions, seed) @@ -623,8 +642,10 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } val sampleCounts = sample.countByKey() val takeSample = sample.collect() - sampleCounts.foreach { case(k, v) => - assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } + sampleCounts.foreach { case (k, v) => + assertBinomialSample(exact = exact, actual = v.toInt, trials = trials(k).toInt, + p = samplingRate) + } assert(takeSample.size === takeSample.toSet.size) takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } } @@ -635,6 +656,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { samplingRate: Double, seed: Long, n: Long): Unit = { + val trials = stratifiedData.countByKey() val expectedSampleSize = stratifiedData.countByKey().mapValues(count => math.ceil(count * samplingRate).toInt) val fractions = Map("1" -> samplingRate, "0" -> samplingRate) @@ -646,7 +668,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val sampleCounts = sample.countByKey() val takeSample = sample.collect() sampleCounts.foreach { case (k, v) => - assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) + assertPoissonSample(exact, actual = v.toInt, trials = trials(k).toInt, p = samplingRate) } val groupedByKey = takeSample.groupBy(_._1) for ((key, v) <- groupedByKey) { @@ -657,7 +679,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { if (exact) { assert(v.toSet.size <= expectedSampleSize(key)) } else { - assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) + assertPoissonSample(false, actual = v.toSet.size, trials(key).toInt, p = samplingRate) } } } diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index d26667bf720cf..a5b50fce5c0a9 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -65,4 +65,19 @@ class XORShiftRandomSuite extends SparkFunSuite with Matchers { val random = new XORShiftRandom(0L) assert(random.nextInt() != 0) } + + test ("hashSeed has random bits throughout") { + val totalBitCount = (0 until 10).map { seed => + val hashed = XORShiftRandom.hashSeed(seed) + val bitCount = java.lang.Long.bitCount(hashed) + // make sure we have roughly equal numbers of 0s and 1s. Mostly just check that we + // don't have all 0s or 1s in the high bits + bitCount should be > 20 + bitCount should be < 44 + bitCount + }.sum + // and over all the seeds, very close to equal numbers of 0s & 1s + totalBitCount should be > (32 * 10 - 30) + totalBitCount should be < (32 * 10 + 30) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 17db8c44777d4..a326432d017fc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -61,8 +61,9 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp val xMean = Array(5.843, 3.057, 3.758, 1.199) val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + // the input seed is somewhat magic, to make this test pass val rdd = sc.parallelize(generateMultinomialLogisticInput( - coefficients, xMean, xVariance, true, nPoints, 42), 2) + coefficients, xMean, xVariance, true, nPoints, 1), 2) val dataFrame = sqlContext.createDataFrame(rdd).toDF("label", "features") val numClasses = 3 val numIterations = 100 @@ -70,7 +71,7 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp val trainer = new MultilayerPerceptronClassifier() .setLayers(layers) .setBlockSize(1) - .setSeed(11L) + .setSeed(11L) // currently this seed is ignored .setMaxIter(numIterations) val model = trainer.fit(dataFrame) val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index a2e46f2029956..23dfdaa9f8fc6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -66,9 +66,12 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { // copied model must have the same parent. MLTestingUtils.checkCopy(model) + // These expectations are just magic values, characterizing the current + // behavior. The test needs to be updated to be more general, see SPARK-11502 + val magicExp = Vectors.dense(0.30153007534417237, -0.6833061711354689, 0.5116530778733167) model.transform(docDF).select("result", "expected").collect().foreach { case Row(vector1: Vector, vector2: Vector) => - assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.") + assert(vector1 ~== magicExp absTol 1E-5, "Transformed vector is different with expected.") } } @@ -99,8 +102,15 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { val realVectors = model.getVectors.sort("word").select("vector").map { case Row(v: Vector) => v }.collect() + // These expectations are just magic values, characterizing the current + // behavior. The test needs to be updated to be more general, see SPARK-11502 + val magicExpected = Seq( + Vectors.dense(0.3326166272163391, -0.5603077411651611, -0.2309209555387497), + Vectors.dense(0.32463887333869934, -0.9306551218032837, 1.393115520477295), + Vectors.dense(-0.27150997519493103, 0.4372006058692932, -0.13465698063373566) + ) - realVectors.zip(expectedVectors).foreach { + realVectors.zip(magicExpected).foreach { case (real, expected) => assert(real ~== expected absTol 1E-5, "Actual vector is different from expected.") } @@ -122,7 +132,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { .setSeed(42L) .fit(docDF) - val expectedSimilarity = Array(0.2789285076917586, -0.6336972059851644) + val expectedSimilarity = Array(0.18032623242822343, -0.5717976464798823) val (synonyms, similarity) = model.findSynonyms("a", 2).map { case Row(w: String, sim: Double) => (w, sim) }.collect().unzip diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala index 3645d29dccdb2..65e37c64d404e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -98,9 +98,16 @@ class StreamingKMeansSuite extends SparkFunSuite with TestSuiteBase { runStreams(ssc, numBatches, numBatches) // check that estimated centers are close to true centers - // NOTE exact assignment depends on the initialization! - assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1) - assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1) + // cluster ordering is arbitrary, so choose closest cluster + val d0 = Vectors.sqdist(kMeans.latestModel().clusterCenters(0), centers(0)) + val d1 = Vectors.sqdist(kMeans.latestModel().clusterCenters(0), centers(1)) + val (c0, c1) = if (d0 < d1) { + (centers(0), centers(1)) + } else { + (centers(1), centers(0)) + } + assert(c0 ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1) + assert(c1 ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1) } test("detecting dying clusters") { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index c7b6dd926c3e8..b02d41b52ab25 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1788,21 +1788,21 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has +----+--------------------+ |word| vector| +----+--------------------+ - | a|[-0.3511952459812...| - | b|[0.29077222943305...| - | c|[0.02315592765808...| + | a|[0.09461779892444...| + | b|[1.15474212169647...| + | c|[-0.3794820010662...| +----+--------------------+ ... >>> model.findSynonyms("a", 2).show() - +----+-------------------+ - |word| similarity| - +----+-------------------+ - | b|0.29255685145799626| - | c|-0.5414068302988307| - +----+-------------------+ + +----+--------------------+ + |word| similarity| + +----+--------------------+ + | b| 0.16782984556103436| + | c|-0.46761559092107646| + +----+--------------------+ ... >>> model.transform(doc).head().model - DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276]) + DenseVector([0.5524, -0.4995, -0.3599, 0.0241, 0.3461]) .. versionadded:: 1.4.0 """ diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index ec5748a1cfe94..b44c66f73cc49 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -76,11 +76,11 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"]) >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0]) >>> predictions[0] - Row(user=0, item=2, prediction=0.39...) + Row(user=0, item=2, prediction=-0.13807615637779236) >>> predictions[1] - Row(user=1, item=0, prediction=3.19...) + Row(user=1, item=0, prediction=2.6258413791656494) >>> predictions[2] - Row(user=2, item=0, prediction=-1.15...) + Row(user=2, item=0, prediction=-1.5018409490585327) .. versionadded:: 1.4.0 """ diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index b9442b0d16c0f..93e47a797f490 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -101,12 +101,12 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader): >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) - 3.8... + 3.73... >>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)]) >>> model = ALS.train(df, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) - 3.8... + 3.73... >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) >>> model.predict(2, 2) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 3baff8147753d..765a4511b64bc 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -436,7 +436,7 @@ def sample(self, withReplacement, fraction, seed=None): """Returns a sampled subset of this :class:`DataFrame`. >>> df.sample(False, 0.5, 42).count() - 1 + 2 """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction seed = seed if seed is not None else random.randint(0, sys.maxsize) @@ -463,8 +463,8 @@ def sampleBy(self, col, fractions, seed=None): +---+-----+ |key|count| +---+-----+ - | 0| 3| - | 1| 8| + | 0| 5| + | 1| 9| +---+-----+ """ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala index 4a644d136f09c..b7a0d44fa7e57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RandomSuite.scala @@ -24,12 +24,12 @@ import org.apache.spark.SparkFunSuite class RandomSuite extends SparkFunSuite with ExpressionEvalHelper { test("random") { - checkDoubleEvaluation(Rand(30), 0.7363714192755834 +- 0.001) - checkDoubleEvaluation(Randn(30), 0.5181478766595276 +- 0.001) + checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001) + checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001) } test("SPARK-9127 codegen with long seed") { - checkDoubleEvaluation(Rand(5419823303878592871L), 0.4061913198963727 +- 0.001) - checkDoubleEvaluation(Randn(5419823303878592871L), -0.24417152005343168 +- 0.001) + checkDoubleEvaluation(Rand(5419823303878592871L), 0.2304755080444375 +- 0.001) + checkDoubleEvaluation(Randn(5419823303878592871L), -1.2824262718225607 +- 0.001) } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 49f516e86d754..40bff57a17a03 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -257,7 +257,9 @@ public void testSampleBy() { DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); DataFrame sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); Row[] actual = sampled.groupBy("key").count().orderBy("key").collect(); - Row[] expected = {RowFactory.create(0, 5), RowFactory.create(1, 8)}; - Assert.assertArrayEquals(expected, actual); + Assert.assertEquals(0, actual[0].getLong(0)); + Assert.assertTrue(0 <= actual[0].getLong(1) && actual[0].getLong(1) <= 8); + Assert.assertEquals(1, actual[1].getLong(0)); + Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13); } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 6524abcf5e97f..b15af42caa3ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -41,7 +41,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val data = sparkContext.parallelize(1 to n, 2).toDF("id") checkAnswer( data.sample(withReplacement = false, 0.05, seed = 13), - Seq(16, 23, 88, 100).map(Row(_)) + Seq(3, 17, 27, 58, 62).map(Row(_)) ) } @@ -186,6 +186,6 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L) checkAnswer( sampled.groupBy("key").count().orderBy("key"), - Seq(Row(0, 5), Row(1, 8))) + Seq(Row(0, 6), Row(1, 11))) } } From f328fedafd7bd084470a5e402de0429b5b7f8cd7 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 6 Nov 2015 12:21:53 -0800 Subject: [PATCH 19/88] [SPARK-11450] [SQL] Add Unsafe Row processing to Expand This PR enables the Expand operator to process and produce Unsafe Rows. Author: Herman van Hovell Closes #9414 from hvanhovell/SPARK-11450. --- .../sql/catalyst/expressions/Projection.scala | 6 ++- .../apache/spark/sql/execution/Expand.scala | 19 ++++--- .../spark/sql/execution/basicOperators.scala | 8 +-- .../spark/sql/execution/ExpandSuite.scala | 54 +++++++++++++++++++ 4 files changed, 73 insertions(+), 14 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index a6fe730f6dad4..79dabe8e925ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -128,7 +128,11 @@ object UnsafeProjection { * Returns an UnsafeProjection for given sequence of Expressions (bounded). */ def create(exprs: Seq[Expression]): UnsafeProjection = { - GenerateUnsafeProjection.generate(exprs) + val unsafeExprs = exprs.map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + GenerateUnsafeProjection.generate(unsafeExprs) } def create(expr: Expression): UnsafeProjection = create(Seq(expr)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index a458881f40948..55e95769d3faa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -41,14 +41,21 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + + private[this] val projection = { + if (outputsUnsafeRows) { + (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) + } else { + (exprs: Seq[Expression]) => newMutableProjection(exprs, child.output)() + } + } + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => - // TODO Move out projection objects creation and transfer to - // workers via closure. However we can't assume the Projection - // is serializable because of the code gen, so we have to - // create the projections within each of the partition processing. - val groups = projections.map(ee => newProjection(ee, child.output)).toArray - + val groups = projections.map(projection).toArray new Iterator[InternalRow] { private[this] var result: InternalRow = _ private[this] var idx = -1 // -1 means the initial state diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index d5a803f8c4b24..799650a4f784f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -67,16 +67,10 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - /** Rewrite the project list to use unsafe expressions as needed. */ - protected val unsafeProjectList = projectList.map(_ transform { - case CreateStruct(children) => CreateStructUnsafe(children) - case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) - }) - protected override def doExecute(): RDD[InternalRow] = { val numRows = longMetric("numRows") child.execute().mapPartitions { iter => - val project = UnsafeProjection.create(unsafeProjectList, child.output) + val project = UnsafeProjection.create(projectList, child.output) iter.map { row => numRows += 1 project(row) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala new file mode 100644 index 0000000000000..faef76d52ae75 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.IntegerType + +class ExpandSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.localSeqToDataFrameHolder + + private def testExpand(f: SparkPlan => SparkPlan): Unit = { + val input = (1 to 1000).map(Tuple1.apply) + val projections = Seq.tabulate(2) { i => + Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil + } + val attributes = projections.head.map(_.toAttribute) + checkAnswer( + input.toDF(), + plan => Expand(projections, attributes, f(plan)), + input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j))) + ) + } + + test("inheriting child row type") { + val exprs = AttributeReference("a", IntegerType, false)() :: Nil + val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty))) + assert(plan.outputsUnsafeRows, "Expand should inherits the created row type from its child.") + } + + test("expanding UnsafeRows") { + testExpand(ConvertToUnsafe) + } + + test("expanding SafeRows") { + testExpand(identity) + } +} From 3a652f691b220fada0286f8d0a562c5657973d4d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 6 Nov 2015 14:47:41 -0800 Subject: [PATCH 20/88] [SPARK-11561][SQL] Rename text data source's column name to value. Author: Reynold Xin Closes #9527 from rxin/SPARK-11561. --- .../sql/execution/datasources/text/DefaultSource.scala | 6 ++---- .../spark/sql/execution/datasources/text/TextSuite.scala | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 52c4421d7e87e..4b8b8e4e74dad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -30,14 +30,12 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, BufferHolder} -import org.apache.spark.sql.columnar.MutableUnsafeRow import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.SerializableConfiguration /** @@ -78,7 +76,7 @@ private[sql] class TextRelation( extends HadoopFsRelation(maybePartitionSpec) { /** Data schema is always a single column, named "text". */ - override def dataSchema: StructType = new StructType().add("text", StringType) + override def dataSchema: StructType = new StructType().add("value", StringType) /** This is an internal data source that outputs internal row format. */ override val needConversion: Boolean = false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 0a2306c06646c..914e516613f9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -65,7 +65,7 @@ class TextSuite extends QueryTest with SharedSQLContext { /** Verifies data and schema. */ private def verifyFrame(df: DataFrame): Unit = { // schema - assert(df.schema == new StructType().add("text", StringType)) + assert(df.schema == new StructType().add("value", StringType)) // verify content val data = df.collect() From c447c9d54603890db7399fb80adc9fae40b71f64 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 6 Nov 2015 14:51:03 -0800 Subject: [PATCH 21/88] [SPARK-11217][ML] save/load for non-meta estimators and transformers This PR implements the default save/load for non-meta estimators and transformers using the JSON serialization of param values. The saved metadata includes: * class name * uid * timestamp * paramMap The save/load interface is similar to DataFrames. We use the current active context by default, which should be sufficient for most use cases. ~~~scala instance.save("path") instance.write.context(sqlContext).overwrite().save("path") Instance.load("path") ~~~ The param handling is different from the design doc. We didn't save default and user-set params separately, and when we load it back, all parameters are user-set. This does cause issues. But it also cause other issues if we modify the default params. TODOs: * [x] Java test * [ ] a follow-up PR to implement default save/load for all non-meta estimators and transformers cc jkbradley Author: Xiangrui Meng Closes #9454 from mengxr/SPARK-11217. --- .../apache/spark/ml/feature/Binarizer.scala | 11 +- .../org/apache/spark/ml/param/params.scala | 2 +- .../org/apache/spark/ml/util/ReadWrite.scala | 220 ++++++++++++++++++ .../ml/util/JavaDefaultReadWriteSuite.java | 74 ++++++ .../spark/ml/feature/BinarizerSuite.scala | 11 +- .../spark/ml/util/DefaultReadWriteTest.scala | 110 +++++++++ .../apache/spark/ml/util/TempDirectory.scala | 45 ++++ 7 files changed, 469 insertions(+), 4 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index edad754436455..e5c25574d4b11 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental final class Binarizer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol { + extends Transformer with Writable with HasInputCol with HasOutputCol { def this() = this(Identifiable.randomUID("binarizer")) @@ -86,4 +86,11 @@ final class Binarizer(override val uid: String) } override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) + + override def write: Writer = new DefaultParamsWriter(this) +} + +object Binarizer extends Readable[Binarizer] { + + override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 8361406f87299..c9325709187c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -592,7 +592,7 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter in the embedded param map. */ - protected final def set[T](param: Param[T], value: T): this.type = { + final def set[T](param: Param[T], value: T): this.type = { set(param -> value) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala new file mode 100644 index 0000000000000..ea790e0dddc7f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.IOException + +import org.apache.hadoop.fs.{FileSystem, Path} +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param.{ParamPair, Params} +import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils + +/** + * Trait for [[Writer]] and [[Reader]]. + */ +private[util] sealed trait BaseReadWrite { + private var optionSQLContext: Option[SQLContext] = None + + /** + * Sets the SQL context to use for saving/loading. + */ + @Since("1.6.0") + def context(sqlContext: SQLContext): this.type = { + optionSQLContext = Option(sqlContext) + this + } + + /** + * Returns the user-specified SQL context or the default. + */ + protected final def sqlContext: SQLContext = optionSQLContext.getOrElse { + SQLContext.getOrCreate(SparkContext.getOrCreate()) + } +} + +/** + * Abstract class for utility classes that can save ML instances. + */ +@Experimental +@Since("1.6.0") +abstract class Writer extends BaseReadWrite { + + protected var shouldOverwrite: Boolean = false + + /** + * Saves the ML instances to the input path. + */ + @Since("1.6.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + def save(path: String): Unit + + /** + * Overwrites if the output path already exists. + */ + @Since("1.6.0") + def overwrite(): this.type = { + shouldOverwrite = true + this + } + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) +} + +/** + * Trait for classes that provide [[Writer]]. + */ +@Since("1.6.0") +trait Writable { + + /** + * Returns a [[Writer]] instance for this ML instance. + */ + @Since("1.6.0") + def write: Writer + + /** + * Saves this ML instance to the input path, a shortcut of `write.save(path)`. + */ + @Since("1.6.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + def save(path: String): Unit = write.save(path) +} + +/** + * Abstract class for utility classes that can load ML instances. + * @tparam T ML instance type + */ +@Experimental +@Since("1.6.0") +abstract class Reader[T] extends BaseReadWrite { + + /** + * Loads the ML component from the input path. + */ + @Since("1.6.0") + def load(path: String): T + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.context(sqlContext) +} + +/** + * Trait for objects that provide [[Reader]]. + * @tparam T ML instance type + */ +@Experimental +@Since("1.6.0") +trait Readable[T] { + + /** + * Returns a [[Reader]] instance for this class. + */ + @Since("1.6.0") + def read: Reader[T] + + /** + * Reads an ML instance from the input path, a shortcut of `read.load(path)`. + */ + @Since("1.6.0") + def load(path: String): T = read.load(path) +} + +/** + * Default [[Writer]] implementation for transformers and estimators that contain basic + * (json4s-serializable) params and no data. This will not handle more complex params or types with + * data (e.g., models with coefficients). + * @param instance object to save + */ +private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging { + + /** + * Saves the ML component to the input path. + */ + override def save(path: String): Unit = { + val sc = sqlContext.sparkContext + + val hadoopConf = sc.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + val p = new Path(path) + if (fs.exists(p)) { + if (shouldOverwrite) { + logInfo(s"Path $path already exists. It will be overwritten.") + // TODO: Revert back to the original content if save is not successful. + fs.delete(p, true) + } else { + throw new IOException( + s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") + } + } + + val uid = instance.uid + val cls = instance.getClass.getName + val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] + val jsonParams = params.map { case ParamPair(p, v) => + p.name -> parse(p.jsonEncode(v)) + }.toList + val metadata = ("class" -> cls) ~ + ("timestamp" -> System.currentTimeMillis()) ~ + ("uid" -> uid) ~ + ("paramMap" -> jsonParams) + val metadataPath = new Path(path, "metadata").toString + val metadataJson = compact(render(metadata)) + sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + } +} + +/** + * Default [[Reader]] implementation for transformers and estimators that contain basic + * (json4s-serializable) params and no data. This will not handle more complex params or types with + * data (e.g., models with coefficients). + * @tparam T ML instance type + */ +private[ml] class DefaultParamsReader[T] extends Reader[T] { + + /** + * Loads the ML component from the input path. + */ + override def load(path: String): T = { + implicit val format = DefaultFormats + val sc = sqlContext.sparkContext + val metadataPath = new Path(path, "metadata").toString + val metadataStr = sc.textFile(metadataPath, 1).first() + val metadata = parse(metadataStr) + val cls = Utils.classForName((metadata \ "class").extract[String]) + val uid = (metadata \ "uid").extract[String] + val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params] + (metadata \ "paramMap") match { + case JObject(pairs) => + pairs.foreach { case (paramName, jsonValue) => + val param = instance.getParam(paramName) + val value = param.jsonDecode(compact(render(jsonValue))) + instance.set(param, value) + } + case _ => + throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.") + } + instance.asInstanceOf[T] + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java new file mode 100644 index 0000000000000..c39538014be81 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util; + +import java.io.File; +import java.io.IOException; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.util.Utils; + +public class JavaDefaultReadWriteSuite { + + JavaSparkContext jsc = null; + File tempDir = null; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite"); + tempDir = Utils.createTempDir( + System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite"); + } + + @After + public void tearDown() { + if (jsc != null) { + jsc.stop(); + jsc = null; + } + Utils.deleteRecursively(tempDir); + } + + @Test + public void testDefaultReadWrite() throws IOException { + String uid = "my_params"; + MyParams instance = new MyParams(uid); + instance.set(instance.intParam(), 2); + String outputPath = new File(tempDir, uid).getPath(); + instance.save(outputPath); + try { + instance.save(outputPath); + Assert.fail( + "Write without overwrite enabled should fail if the output directory already exists."); + } catch (IOException e) { + // expected + } + SQLContext sqlContext = new SQLContext(jsc); + instance.write().context(sqlContext).overwrite().save(outputPath); + MyParams newInstance = MyParams.load(outputPath); + Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid()); + Assert.assertEquals("Params should be preserved.", + 2, newInstance.getOrDefault(newInstance.intParam())); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 2086043983661..9dfa1439cc303 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -19,10 +19,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} -class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var data: Array[Double] = _ @@ -66,4 +67,12 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(x === y, "The feature value is not correct after binarization.") } } + + test("read/write") { + val binarizer = new Binarizer() + .setInputCol("feature") + .setOutputCol("binarized_feature") + .setThreshold(0.1) + testDefaultReadWrite(binarizer) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala new file mode 100644 index 0000000000000..4545b0f281f5a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.{File, IOException} + +import org.scalatest.Suite + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.util.MLlibTestSparkContext + +trait DefaultReadWriteTest extends TempDirectory { self: Suite => + + /** + * Checks "overwrite" option and params. + * @param instance ML instance to test saving/loading + * @tparam T ML instance type + */ + def testDefaultReadWrite[T <: Params with Writable](instance: T): Unit = { + val uid = instance.uid + val path = new File(tempDir, uid).getPath + + instance.save(path) + intercept[IOException] { + instance.save(path) + } + instance.write.overwrite().save(path) + val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]] + val newInstance = loader.load(path) + + assert(newInstance.uid === instance.uid) + instance.params.foreach { p => + if (instance.isDefined(p)) { + (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { + case (Array(values), Array(newValues)) => + assert(values === newValues, s"Values do not match on param ${p.name}.") + case (value, newValue) => + assert(value === newValue, s"Values do not match on param ${p.name}.") + } + } else { + assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.") + } + } + + val load = instance.getClass.getMethod("load", classOf[String]) + val another = load.invoke(instance, path).asInstanceOf[T] + assert(another.uid === instance.uid) + } +} + +class MyParams(override val uid: String) extends Params with Writable { + + final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") + final val intParam: IntParam = new IntParam(this, "intParam", "doc") + final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc") + final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc") + final val longParam: LongParam = new LongParam(this, "longParam", "doc") + final val stringParam: Param[String] = new Param[String](this, "stringParam", "doc") + final val intArrayParam: IntArrayParam = new IntArrayParam(this, "intArrayParam", "doc") + final val doubleArrayParam: DoubleArrayParam = + new DoubleArrayParam(this, "doubleArrayParam", "doc") + final val stringArrayParam: StringArrayParam = + new StringArrayParam(this, "stringArrayParam", "doc") + + setDefault(intParamWithDefault -> 0) + set(intParam -> 1) + set(floatParam -> 2.0f) + set(doubleParam -> 3.0) + set(longParam -> 4L) + set(stringParam -> "5") + set(intArrayParam -> Array(6, 7)) + set(doubleArrayParam -> Array(8.0, 9.0)) + set(stringArrayParam -> Array("10", "11")) + + override def copy(extra: ParamMap): Params = defaultCopy(extra) + + override def write: Writer = new DefaultParamsWriter(this) +} + +object MyParams extends Readable[MyParams] { + + override def read: Reader[MyParams] = new DefaultParamsReader[MyParams] + + override def load(path: String): MyParams = read.load(path) +} + +class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + + test("default read/write") { + val myParams = new MyParams("my_params") + testDefaultReadWrite(myParams) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala new file mode 100644 index 0000000000000..2742026a69c2e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.File + +import org.scalatest.{BeforeAndAfterAll, Suite} + +import org.apache.spark.util.Utils + +/** + * Trait that creates a temporary directory before all tests and deletes it after all. + */ +trait TempDirectory extends BeforeAndAfterAll { self: Suite => + + private var _tempDir: File = _ + + /** Returns the temporary directory as a [[File]] instance. */ + protected def tempDir: File = _tempDir + + override def beforeAll(): Unit = { + super.beforeAll() + _tempDir = Utils.createTempDir(this.getClass.getName) + } + + override def afterAll(): Unit = { + Utils.deleteRecursively(_tempDir) + super.afterAll() + } +} From f6680cdc5d2912dea9768ef5c3e2cc101b06daf8 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Fri, 6 Nov 2015 15:24:33 -0800 Subject: [PATCH 22/88] [SPARK-11555] spark on yarn spark-class --num-workers doesn't work I tested the various options with both spark-submit and spark-class of specifying number of executors in both client and cluster mode where it applied. --num-workers, --num-executors, spark.executor.instances, SPARK_EXECUTOR_INSTANCES, default nothing supplied Author: Thomas Graves Closes #9523 from tgravescs/SPARK-11555. --- .../org/apache/spark/deploy/yarn/ClientArguments.scala | 2 +- .../org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 1165061db21e3..a9f4374357356 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -81,7 +81,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) .orNull // If dynamic allocation is enabled, start at the configured initial number of executors. // Default to minExecutors if no initialExecutors is set. - numExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf) + numExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf, numExecutors) principal = Option(principal) .orElse(sparkConf.getOption("spark.yarn.principal")) .orNull diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 561ad79ee0228..a290ebeec9001 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -392,8 +392,11 @@ object YarnSparkHadoopUtil { /** * Getting the initial target number of executors depends on whether dynamic allocation is * enabled. + * If not using dynamic allocation it gets the number of executors reqeusted by the user. */ - def getInitialTargetExecutorNumber(conf: SparkConf): Int = { + def getInitialTargetExecutorNumber( + conf: SparkConf, + numExecutors: Int = DEFAULT_NUMBER_EXECUTORS): Int = { if (Utils.isDynamicAllocationEnabled(conf)) { val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0) val initialNumExecutors = @@ -406,7 +409,7 @@ object YarnSparkHadoopUtil { initialNumExecutors } else { val targetNumExecutors = - sys.env.get("SPARK_EXECUTOR_INSTANCES").map(_.toInt).getOrElse(DEFAULT_NUMBER_EXECUTORS) + sys.env.get("SPARK_EXECUTOR_INSTANCES").map(_.toInt).getOrElse(numExecutors) // System property can override environment variable. conf.getInt("spark.executor.instances", targetNumExecutors) } From 7e9a9e603abce8689938bdd62d04b29299644aa4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 6 Nov 2015 15:37:07 -0800 Subject: [PATCH 23/88] [SPARK-11269][SQL] Java API support & test cases for Dataset This simply brings https://github.com/apache/spark/pull/9358 up-to-date. Author: Wenchen Fan Author: Reynold Xin Closes #9528 from rxin/dataset-java. --- .../spark/sql/catalyst/encoders/Encoder.scala | 123 +++++- .../sql/catalyst/expressions/objects.scala | 21 ++ .../scala/org/apache/spark/sql/Dataset.scala | 126 ++++++- .../org/apache/spark/sql/DatasetHolder.scala | 6 +- .../org/apache/spark/sql/GroupedDataset.scala | 17 + .../org/apache/spark/sql/SQLContext.scala | 4 + .../apache/spark/sql/JavaDatasetSuite.java | 357 ++++++++++++++++++ .../spark/sql/DatasetPrimitiveSuite.scala | 2 +- 8 files changed, 644 insertions(+), 12 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala index 329a132d3d8b2..f05e18288de2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.encoders - - import scala.reflect.ClassTag -import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils +import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType} +import org.apache.spark.sql.catalyst.expressions._ /** * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. @@ -37,3 +37,120 @@ trait Encoder[T] extends Serializable { /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ def clsTag: ClassTag[T] } + +object Encoder { + import scala.reflect.runtime.universe._ + + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true) + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true) + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true) + def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true) + def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true) + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true) + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true) + def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true) + + def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = { + tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2)]] + } + + def tuple[T1, T2, T3]( + enc1: Encoder[T1], + enc2: Encoder[T2], + enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = { + tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]] + } + + def tuple[T1, T2, T3, T4]( + enc1: Encoder[T1], + enc2: Encoder[T2], + enc3: Encoder[T3], + enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { + tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]] + } + + def tuple[T1, T2, T3, T4, T5]( + enc1: Encoder[T1], + enc2: Encoder[T2], + enc3: Encoder[T3], + enc4: Encoder[T4], + enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { + tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]])) + .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]] + } + + private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { + assert(encoders.length > 1) + // make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`. + assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty)) + + val schema = StructType(encoders.zipWithIndex.map { + case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema) + }) + + val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") + + val extractExpressions = encoders.map { + case e if e.flat => e.extractExpressions.head + case other => CreateStruct(other.extractExpressions) + }.zipWithIndex.map { case (expr, index) => + expr.transformUp { + case BoundReference(0, t: ObjectType, _) => + Invoke( + BoundReference(0, ObjectType(cls), true), + s"_${index + 1}", + t) + } + } + + val constructExpressions = encoders.zipWithIndex.map { case (enc, index) => + if (enc.flat) { + enc.constructExpression.transform { + case b: BoundReference => b.copy(ordinal = index) + } + } else { + enc.constructExpression.transformUp { + case BoundReference(ordinal, dt, _) => + GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt) + } + } + } + + val constructExpression = + NewInstance(cls, constructExpressions, false, ObjectType(cls)) + + new ExpressionEncoder[Any]( + schema, + false, + extractExpressions, + constructExpression, + ClassTag.apply(cls)) + } + + + def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)] + + private def getTypeTag[T](c: Class[T]): TypeTag[T] = { + import scala.reflect.api + + // val mirror = runtimeMirror(c.getClassLoader) + val mirror = rootMirror + val sym = mirror.staticClass(c.getName) + val tpe = sym.selfType + TypeTag(mirror, new api.TypeCreator { + def apply[U <: api.Universe with Singleton](m: api.Mirror[U]) = + if (m eq mirror) tpe.asInstanceOf[U # Type] + else throw new IllegalArgumentException( + s"Type tag defined in $mirror cannot be migrated to other mirrors.") + }) + } + + def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = { + implicit val typeTag1 = getTypeTag(c1) + implicit val typeTag2 = getTypeTag(c2) + ExpressionEncoder[(T1, T2)]() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 81855289762c6..4f58464221b4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -491,3 +491,24 @@ case class CreateExternalRow(children: Seq[Expression]) extends Expression { s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);" } } + +case class GetInternalRowField(child: Expression, ordinal: Int, dataType: DataType) + extends UnaryExpression { + + override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val row = child.gen(ctx) + s""" + ${row.code} + final boolean ${ev.isNull} = ${row.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = ${ctx.getValue(row.value, dataType, ordinal.toString)}; + } + """ + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 4bca9c3b3fe54..fecbdac9a6004 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias +import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _} + import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner @@ -151,18 +155,37 @@ class Dataset[T] private[sql]( def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) /** + * (Scala-specific) * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. * @since 1.6.0 */ def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func)) /** + * (Java-specific) + * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * @since 1.6.0 + */ + def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] = + filter(t => func.call(t).booleanValue()) + + /** + * (Scala-specific) * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) /** + * (Java-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * @since 1.6.0 + */ + def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] = + map(t => func.call(t))(encoder) + + /** + * (Scala-specific) * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ @@ -177,30 +200,77 @@ class Dataset[T] private[sql]( logicalPlan)) } + /** + * (Java-specific) + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * @since 1.6.0 + */ + def mapPartitions[U]( + f: FlatMapFunction[java.util.Iterator[T], U], + encoder: Encoder[U]): Dataset[U] = { + val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala + mapPartitions(func)(encoder) + } + + /** + * (Scala-specific) + * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * and then flattening the results. + * @since 1.6.0 + */ def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = mapPartitions(_.flatMap(func)) + /** + * (Java-specific) + * Returns a new [[Dataset]] by first applying a function to all elements of this [[Dataset]], + * and then flattening the results. + * @since 1.6.0 + */ + def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + val func: (T) => Iterable[U] = x => f.call(x).asScala + flatMap(func)(encoder) + } + /* ************** * * Side effects * * ************** */ /** + * (Scala-specific) * Runs `func` on each element of this Dataset. * @since 1.6.0 */ def foreach(func: T => Unit): Unit = rdd.foreach(func) /** + * (Java-specific) + * Runs `func` on each element of this Dataset. + * @since 1.6.0 + */ + def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_)) + + /** + * (Scala-specific) * Runs `func` on each partition of this Dataset. * @since 1.6.0 */ def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func) + /** + * (Java-specific) + * Runs `func` on each partition of this Dataset. + * @since 1.6.0 + */ + def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit = + foreachPartition(it => func.call(it.asJava)) + /* ************* * * Aggregation * * ************* */ /** + * (Scala-specific) * Reduces the elements of this Dataset using the specified binary function. The given function * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 @@ -208,6 +278,15 @@ class Dataset[T] private[sql]( def reduce(func: (T, T) => T): T = rdd.reduce(func) /** + * (Java-specific) + * Reduces the elements of this Dataset using the specified binary function. The given function + * must be commutative and associative or the result may be non-deterministic. + * @since 1.6.0 + */ + def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _)) + + /** + * (Scala-specific) * Aggregates the elements of each partition, and then the results for all the partitions, using a * given associative and commutative function and a neutral "zero value". * @@ -221,6 +300,15 @@ class Dataset[T] private[sql]( def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op) /** + * (Java-specific) + * Aggregates the elements of each partition, and then the results for all the partitions, using a + * given associative and commutative function and a neutral "zero value". + * @since 1.6.0 + */ + def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _)) + + /** + * (Scala-specific) * Returns a [[GroupedDataset]] where the data is grouped by the given key function. * @since 1.6.0 */ @@ -258,6 +346,14 @@ class Dataset[T] private[sql]( keyAttributes) } + /** + * (Java-specific) + * Returns a [[GroupedDataset]] where the data is grouped by the given key function. + * @since 1.6.0 + */ + def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + groupBy(f.call(_))(encoder) + /* ****************** * * Typed Relational * * ****************** */ @@ -267,8 +363,7 @@ class Dataset[T] private[sql]( * {{{ * df.select($"colA", $"colB" + 1) * }}} - * @group dfops - * @since 1.3.0 + * @since 1.6.0 */ // Copied from Dataframe to make sure we don't have invalid overloads. @scala.annotation.varargs @@ -279,7 +374,7 @@ class Dataset[T] private[sql]( * * {{{ * val ds = Seq(1, 2, 3).toDS() - * val newDS = ds.select(e[Int]("value + 1")) + * val newDS = ds.select(expr("value + 1").as[Int]) * }}} * @since 1.6.0 */ @@ -405,6 +500,8 @@ class Dataset[T] private[sql]( * This type of join can be useful both for preserving type-safety with the original object * types as well as working with relational data where either side of the join has column * names in common. + * + * @since 1.6.0 */ def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { val left = this.logicalPlan @@ -438,12 +535,31 @@ class Dataset[T] private[sql]( * Gather to Driver Actions * * ************************** */ - /** Returns the first element in this [[Dataset]]. */ + /** + * Returns the first element in this [[Dataset]]. + * @since 1.6.0 + */ def first(): T = rdd.first() - /** Collects the elements to an Array. */ + /** + * Collects the elements to an Array. + * @since 1.6.0 + */ def collect(): Array[T] = rdd.collect() + /** + * (Java-specific) + * Collects the elements to a Java list. + * + * Due to the incompatibility problem between Scala and Java, the return type of [[collect()]] at + * Java side is `java.lang.Object`, which is not easy to use. Java user can use this method + * instead and keep the generic type for result. + * + * @since 1.6.0 + */ + def collectAsList(): java.util.List[T] = + rdd.collect().toSeq.asJava + /** Returns the first `num` elements of this [[Dataset]] as an Array. */ def take(num: Int): Array[T] = rdd.take(num) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 45f0098b92887..08097e9f02084 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -27,9 +27,9 @@ package org.apache.spark.sql * * @since 1.6.0 */ -case class DatasetHolder[T] private[sql](private val df: Dataset[T]) { +case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) { // This is declared with parentheses to prevent the Scala compiler from treating - // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDS(): Dataset[T] = df + // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset. + def toDS(): Dataset[T] = ds } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index b8fc373dffcf5..b2803d5a9a1e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql +import java.util.{Iterator => JIterator} +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute} @@ -104,6 +108,12 @@ class GroupedDataset[K, T] private[sql]( MapGroups(f, groupingAttributes, logicalPlan)) } + def mapGroups[U]( + f: JFunction2[K, JIterator[T], JIterator[U]], + encoder: Encoder[U]): Dataset[U] = { + mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) + } + // To ensure valid overloading. protected def agg(expr: Column, exprs: Column*): DataFrame = groupedData.agg(expr, exprs: _*) @@ -196,4 +206,11 @@ class GroupedDataset[K, T] private[sql]( this.logicalPlan, other.logicalPlan)) } + + def cogroup[U, R]( + other: GroupedDataset[K, U], + f: JFunction3[K, JIterator[T], JIterator[U], JIterator[R]], + encoder: Encoder[R]): Dataset[R] = { + cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 5ad3871093fc8..5598731af5fcc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -508,6 +508,10 @@ class SQLContext private[sql]( new Dataset[T](this, plan) } + def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { + createDataset(data.asScala) + } + /** * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be * converted to Catalyst rows. diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java new file mode 100644 index 0000000000000..a9493d576d179 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import java.io.Serializable; +import java.util.*; + +import scala.Tuple2; +import scala.Tuple3; +import scala.Tuple4; +import scala.Tuple5; +import org.junit.*; + +import org.apache.spark.Accumulator; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.function.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.catalyst.encoders.Encoder; +import org.apache.spark.sql.catalyst.encoders.Encoder$; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.GroupedDataset; +import org.apache.spark.sql.test.TestSQLContext; + +import static org.apache.spark.sql.functions.*; + +public class JavaDatasetSuite implements Serializable { + private transient JavaSparkContext jsc; + private transient TestSQLContext context; + private transient Encoder$ e = Encoder$.MODULE$; + + @Before + public void setUp() { + // Trigger static initializer of TestData + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); + } + + @After + public void tearDown() { + context.sparkContext().stop(); + context = null; + jsc = null; + } + + private Tuple2 tuple2(T1 t1, T2 t2) { + return new Tuple2(t1, t2); + } + + @Test + public void testCollect() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, e.STRING()); + String[] collected = (String[]) ds.collect(); + Assert.assertEquals(Arrays.asList("hello", "world"), Arrays.asList(collected)); + } + + @Test + public void testCommonOperation() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, e.STRING()); + Assert.assertEquals("hello", ds.first()); + + Dataset filtered = ds.filter(new Function() { + @Override + public Boolean call(String v) throws Exception { + return v.startsWith("h"); + } + }); + Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList()); + + + Dataset mapped = ds.map(new Function() { + @Override + public Integer call(String v) throws Exception { + return v.length(); + } + }, e.INT()); + Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList()); + + Dataset parMapped = ds.mapPartitions(new FlatMapFunction, String>() { + @Override + public Iterable call(Iterator it) throws Exception { + List ls = new LinkedList(); + while (it.hasNext()) { + ls.add(it.next().toUpperCase()); + } + return ls; + } + }, e.STRING()); + Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), parMapped.collectAsList()); + + Dataset flatMapped = ds.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String s) throws Exception { + List ls = new LinkedList(); + for (char c : s.toCharArray()) { + ls.add(String.valueOf(c)); + } + return ls; + } + }, e.STRING()); + Assert.assertEquals( + Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"), + flatMapped.collectAsList()); + } + + @Test + public void testForeach() { + final Accumulator accum = jsc.accumulator(0); + List data = Arrays.asList("a", "b", "c"); + Dataset ds = context.createDataset(data, e.STRING()); + + ds.foreach(new VoidFunction() { + @Override + public void call(String s) throws Exception { + accum.add(1); + } + }); + Assert.assertEquals(3, accum.value().intValue()); + } + + @Test + public void testReduce() { + List data = Arrays.asList(1, 2, 3); + Dataset ds = context.createDataset(data, e.INT()); + + int reduced = ds.reduce(new Function2() { + @Override + public Integer call(Integer v1, Integer v2) throws Exception { + return v1 + v2; + } + }); + Assert.assertEquals(6, reduced); + + int folded = ds.fold(1, new Function2() { + @Override + public Integer call(Integer v1, Integer v2) throws Exception { + return v1 * v2; + } + }); + Assert.assertEquals(6, folded); + } + + @Test + public void testGroupBy() { + List data = Arrays.asList("a", "foo", "bar"); + Dataset ds = context.createDataset(data, e.STRING()); + GroupedDataset grouped = ds.groupBy(new Function() { + @Override + public Integer call(String v) throws Exception { + return v.length(); + } + }, e.INT()); + + Dataset mapped = grouped.mapGroups( + new Function2, Iterator>() { + @Override + public Iterator call(Integer key, Iterator data) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (data.hasNext()) { + sb.append(data.next()); + } + return Collections.singletonList(sb.toString()).iterator(); + } + }, + e.STRING()); + + Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + + List data2 = Arrays.asList(2, 6, 10); + Dataset ds2 = context.createDataset(data2, e.INT()); + GroupedDataset grouped2 = ds2.groupBy(new Function() { + @Override + public Integer call(Integer v) throws Exception { + return v / 2; + } + }, e.INT()); + + Dataset cogrouped = grouped.cogroup( + grouped2, + new Function3, Iterator, Iterator>() { + @Override + public Iterator call( + Integer key, + Iterator left, + Iterator right) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (left.hasNext()) { + sb.append(left.next()); + } + sb.append("#"); + while (right.hasNext()) { + sb.append(right.next()); + } + return Collections.singletonList(sb.toString()).iterator(); + } + }, + e.STRING()); + + Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList()); + } + + @Test + public void testGroupByColumn() { + List data = Arrays.asList("a", "foo", "bar"); + Dataset ds = context.createDataset(data, e.STRING()); + GroupedDataset grouped = ds.groupBy(length(col("value"))).asKey(e.INT()); + + Dataset mapped = grouped.mapGroups( + new Function2, Iterator>() { + @Override + public Iterator call(Integer key, Iterator data) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (data.hasNext()) { + sb.append(data.next()); + } + return Collections.singletonList(sb.toString()).iterator(); + } + }, + e.STRING()); + + Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + } + + @Test + public void testSelect() { + List data = Arrays.asList(2, 6); + Dataset ds = context.createDataset(data, e.INT()); + + Dataset> selected = ds.select( + expr("value + 1").as(e.INT()), + col("value").cast("string").as(e.STRING())); + + Assert.assertEquals( + Arrays.asList(tuple2(3, "2"), tuple2(7, "6")), + selected.collectAsList()); + } + + @Test + public void testSetOperation() { + List data = Arrays.asList("abc", "abc", "xyz"); + Dataset ds = context.createDataset(data, e.STRING()); + + Assert.assertEquals( + Arrays.asList("abc", "xyz"), + sort(ds.distinct().collectAsList().toArray(new String[0]))); + + List data2 = Arrays.asList("xyz", "foo", "foo"); + Dataset ds2 = context.createDataset(data2, e.STRING()); + + Dataset intersected = ds.intersect(ds2); + Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList()); + + Dataset unioned = ds.union(ds2); + Assert.assertEquals( + Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"), + sort(unioned.collectAsList().toArray(new String[0]))); + + Dataset subtracted = ds.subtract(ds2); + Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList()); + } + + private > List sort(T[] data) { + Arrays.sort(data); + return Arrays.asList(data); + } + + @Test + public void testJoin() { + List data = Arrays.asList(1, 2, 3); + Dataset ds = context.createDataset(data, e.INT()).as("a"); + List data2 = Arrays.asList(2, 3, 4); + Dataset ds2 = context.createDataset(data2, e.INT()).as("b"); + + Dataset> joined = + ds.joinWith(ds2, col("a.value").equalTo(col("b.value"))); + Assert.assertEquals( + Arrays.asList(tuple2(2, 2), tuple2(3, 3)), + joined.collectAsList()); + } + + @Test + public void testTupleEncoder() { + Encoder> encoder2 = e.tuple(e.INT(), e.STRING()); + List> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b")); + Dataset> ds2 = context.createDataset(data2, encoder2); + Assert.assertEquals(data2, ds2.collectAsList()); + + Encoder> encoder3 = e.tuple(e.INT(), e.LONG(), e.STRING()); + List> data3 = + Arrays.asList(new Tuple3(1, 2L, "a")); + Dataset> ds3 = context.createDataset(data3, encoder3); + Assert.assertEquals(data3, ds3.collectAsList()); + + Encoder> encoder4 = + e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING()); + List> data4 = + Arrays.asList(new Tuple4(1, "b", 2L, "a")); + Dataset> ds4 = context.createDataset(data4, encoder4); + Assert.assertEquals(data4, ds4.collectAsList()); + + Encoder> encoder5 = + e.tuple(e.INT(), e.STRING(), e.LONG(), e.STRING(), e.BOOLEAN()); + List> data5 = + Arrays.asList(new Tuple5(1, "b", 2L, "a", true)); + Dataset> ds5 = + context.createDataset(data5, encoder5); + Assert.assertEquals(data5, ds5.collectAsList()); + } + + @Test + public void testNestedTupleEncoder() { + // test ((int, string), string) + Encoder, String>> encoder = + e.tuple(e.tuple(e.INT(), e.STRING()), e.STRING()); + List, String>> data = + Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b")); + Dataset, String>> ds = context.createDataset(data, encoder); + Assert.assertEquals(data, ds.collectAsList()); + + // test (int, (string, string, long)) + Encoder>> encoder2 = + e.tuple(e.INT(), e.tuple(e.STRING(), e.STRING(), e.LONG())); + List>> data2 = + Arrays.asList(tuple2(1, new Tuple3("a", "b", 3L))); + Dataset>> ds2 = + context.createDataset(data2, encoder2); + Assert.assertEquals(data2, ds2.collectAsList()); + + // test (int, ((string, long), string)) + Encoder, String>>> encoder3 = + e.tuple(e.INT(), e.tuple(e.tuple(e.STRING(), e.LONG()), e.STRING())); + List, String>>> data3 = + Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b"))); + Dataset, String>>> ds3 = + context.createDataset(data3, encoder3); + Assert.assertEquals(data3, ds3.collectAsList()); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 32443557fb8e0..e3b0346f857d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -59,7 +59,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("foreach") { val ds = Seq(1, 2, 3).toDS() val acc = sparkContext.accumulator(0) - ds.foreach(acc +=) + ds.foreach(acc += _) assert(acc.value == 6) } From 1ab72b08601a1c8a674bdd3fab84d9804899b2c7 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Fri, 6 Nov 2015 15:48:20 -0800 Subject: [PATCH 24/88] =?UTF-8?q?[SPARK-11410]=20[PYSPARK]=20Add=20python?= =?UTF-8?q?=20bindings=20for=20repartition=20and=20sortW=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ithinPartitions. Author: Nong Li Closes #9504 from nongli/spark-11410. --- python/pyspark/sql/dataframe.py | 117 +++++++++++++++++++++++++++----- 1 file changed, 101 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 765a4511b64bc..b97c94dad834a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -422,6 +422,67 @@ def repartition(self, numPartitions): """ return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + @since(1.3) + def repartition(self, numPartitions, *cols): + """ + Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The + resulting DataFrame is hash partitioned. + + ``numPartitions`` can be an int to specify the target number of partitions or a Column. + If it is a Column, it will be used as the first partitioning column. If not specified, + the default number of partitions is used. + + .. versionchanged:: 1.6 + Added optional arguments to specify the partitioning columns. Also made numPartitions + optional if partitioning columns are specified. + + >>> df.repartition(10).rdd.getNumPartitions() + 10 + >>> data = df.unionAll(df).repartition("age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 2|Alice| + | 5| Bob| + | 5| Bob| + +---+-----+ + >>> data = data.repartition(7, "age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 5| Bob| + | 5| Bob| + | 2|Alice| + | 2|Alice| + +---+-----+ + >>> data.rdd.getNumPartitions() + 7 + >>> data = data.repartition("name", "age") + >>> data.show() + +---+-----+ + |age| name| + +---+-----+ + | 5| Bob| + | 5| Bob| + | 2|Alice| + | 2|Alice| + +---+-----+ + """ + if isinstance(numPartitions, int): + if len(cols) == 0: + return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx) + else: + return DataFrame( + self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx) + elif isinstance(numPartitions, (basestring, Column)): + cols = (numPartitions, ) + cols + return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx) + else: + raise TypeError("numPartitions should be an int or Column") + @since(1.3) def distinct(self): """Returns a new :class:`DataFrame` containing the distinct rows in this :class:`DataFrame`. @@ -589,6 +650,26 @@ def join(self, other, on=None, how=None): jdf = self._jdf.join(other._jdf, on._jc, how) return DataFrame(jdf, self.sql_ctx) + @since(1.6) + def sortWithinPartitions(self, *cols, **kwargs): + """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s). + + :param cols: list of :class:`Column` or column names to sort by. + :param ascending: boolean or list of boolean (default True). + Sort ascending vs. descending. Specify list for multiple sort orders. + If a list is specified, length of the list must equal length of the `cols`. + + >>> df.sortWithinPartitions("age", ascending=False).show() + +---+-----+ + |age| name| + +---+-----+ + | 2|Alice| + | 5| Bob| + +---+-----+ + """ + jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs)) + return DataFrame(jdf, self.sql_ctx) + @ignore_unicode_prefix @since(1.3) def sort(self, *cols, **kwargs): @@ -613,22 +694,7 @@ def sort(self, *cols, **kwargs): >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] """ - if not cols: - raise ValueError("should sort by at least one column") - if len(cols) == 1 and isinstance(cols[0], list): - cols = cols[0] - jcols = [_to_java_column(c) for c in cols] - ascending = kwargs.get('ascending', True) - if isinstance(ascending, (bool, int)): - if not ascending: - jcols = [jc.desc() for jc in jcols] - elif isinstance(ascending, list): - jcols = [jc if asc else jc.desc() - for asc, jc in zip(ascending, jcols)] - else: - raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) - - jdf = self._jdf.sort(self._jseq(jcols)) + jdf = self._jdf.sort(self._sort_cols(cols, kwargs)) return DataFrame(jdf, self.sql_ctx) orderBy = sort @@ -650,6 +716,25 @@ def _jcols(self, *cols): cols = cols[0] return self._jseq(cols, _to_java_column) + def _sort_cols(self, cols, kwargs): + """ Return a JVM Seq of Columns that describes the sort order + """ + if not cols: + raise ValueError("should sort by at least one column") + if len(cols) == 1 and isinstance(cols[0], list): + cols = cols[0] + jcols = [_to_java_column(c) for c in cols] + ascending = kwargs.get('ascending', True) + if isinstance(ascending, (bool, int)): + if not ascending: + jcols = [jc.desc() for jc in jcols] + elif isinstance(ascending, list): + jcols = [jc if asc else jc.desc() + for asc, jc in zip(ascending, jcols)] + else: + raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending)) + return self._jseq(jcols) + @since("1.3.1") def describe(self, *cols): """Computes statistics for numeric columns. From 6d0ead322e72303c6444c6ac641378a4690cde96 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 6 Nov 2015 16:04:20 -0800 Subject: [PATCH 25/88] [SPARK-9241][SQL] Supporting multiple DISTINCT columns (2) - Rewriting Rule The second PR for SPARK-9241, this adds support for multiple distinct columns to the new aggregation code path. This PR solves the multiple DISTINCT column problem by rewriting these Aggregates into an Expand-Aggregate-Aggregate combination. See the [JIRA ticket](https://issues.apache.org/jira/browse/SPARK-9241) for some information on this. The advantages over the - competing - [first PR](https://github.com/apache/spark/pull/9280) are: - This can use the faster TungstenAggregate code path. - It is impossible to OOM due to an ```OpenHashSet``` allocating to much memory. However, this will multiply the number of input rows by the number of distinct clauses (plus one), and puts a lot more memory pressure on the aggregation code path itself. The location of this Rule is a bit funny, and should probably change when the old aggregation path is changed. cc yhuai - Could you also tell me where to add tests for this? Author: Herman van Hovell Closes #9406 from hvanhovell/SPARK-9241-rewriter. --- .../expressions/aggregate/Count.scala | 2 + .../expressions/aggregate/Utils.scala | 186 +++++++++++++++++- .../expressions/aggregate/interfaces.scala | 6 + .../sql/catalyst/optimizer/Optimizer.scala | 6 +- .../plans/logical/basicOperators.scala | 80 ++++---- .../spark/sql/execution/SparkStrategies.scala | 2 +- 6 files changed, 238 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 54df96cd2446a..ec0c8b483a909 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -49,4 +49,6 @@ case class Count(child: Expression) extends DeclarativeAggregate { ) override val evaluateExpression = Cast(count, LongType) + + override def defaultResult: Option[Literal] = Option(Literal(0L)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala index 644c6211d5f31..39010c3be6d4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.types.{StructType, MapType, ArrayType} +import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType} /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -41,7 +42,7 @@ object Utils { private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { case p: Aggregate if supportsGroupingKeySchema(p) => - val converted = p.transformExpressionsDown { + val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown { case expressions.Average(child) => aggregate.AggregateExpression2( aggregateFunction = aggregate.Average(child), @@ -144,7 +145,8 @@ object Utils { aggregateFunction = aggregate.VarianceSamp(child), mode = aggregate.Complete, isDistinct = false) - } + }) + // Check if there is any expressions.AggregateExpression1 left. // If so, we cannot convert this plan. val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => @@ -156,6 +158,7 @@ object Utils { } // Check if there are multiple distinct columns. + // TODO remove this. val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => expr.collect { case agg: AggregateExpression2 => agg @@ -213,3 +216,178 @@ object Utils { case other => None } } + +/** + * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double + * aggregation in which the regular aggregation expressions and every distinct clause is aggregated + * in a separate group. The results are then combined in a second aggregate. + * + * TODO Expression cannocalization + * TODO Eliminate foldable expressions from distinct clauses. + * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate + * operator. Perhaps this is a good thing? It is much simpler to plan later on... + */ +object MultipleDistinctRewriter extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case a: Aggregate => rewrite(a) + case p => p + } + + def rewrite(a: Aggregate): Aggregate = { + + // Collect all aggregate expressions. + val aggExpressions = a.aggregateExpressions.flatMap { e => + e.collect { + case ae: AggregateExpression2 => ae + } + } + + // Extract distinct aggregate expressions. + val distinctAggGroups = aggExpressions + .filter(_.isDistinct) + .groupBy(_.aggregateFunction.children.toSet) + + // Only continue to rewrite if there is more than one distinct group. + if (distinctAggGroups.size > 1) { + // Create the attributes for the grouping id and the group by clause. + val gid = new AttributeReference("gid", IntegerType, false)() + val groupByMap = a.groupingExpressions.collect { + case ne: NamedExpression => ne -> ne.toAttribute + case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)() + } + val groupByAttrs = groupByMap.map(_._2) + + // Functions used to modify aggregate functions and their inputs. + def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) + def patchAggregateFunctionChildren( + af: AggregateFunction2, + id: Literal, + attrs: Map[Expression, Expression]): AggregateFunction2 = { + af.withNewChildren(af.children.map { case afc => + evalWithinGroup(id, attrs(afc)) + }).asInstanceOf[AggregateFunction2] + } + + // Setup unique distinct aggregate children. + val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq + val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair).toMap + val distinctAggChildAttrs = distinctAggChildAttrMap.values.toSeq + + // Setup expand & aggregate operators for distinct aggregate expressions. + val distinctAggOperatorMap = distinctAggGroups.toSeq.zipWithIndex.map { + case ((group, expressions), i) => + val id = Literal(i + 1) + + // Expand projection + val projection = distinctAggChildren.map { + case e if group.contains(e) => e + case e => nullify(e) + } :+ id + + // Final aggregate + val operators = expressions.map { e => + val af = e.aggregateFunction + val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap) + (e, e.copy(aggregateFunction = naf, isDistinct = false)) + } + + (projection, operators) + } + + // Setup expand for the 'regular' aggregate expressions. + val regularAggExprs = aggExpressions.filter(!_.isDistinct) + val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct + val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap + + // Setup aggregates for 'regular' aggregate expressions. + val regularGroupId = Literal(0) + val regularAggOperatorMap = regularAggExprs.map { e => + // Perform the actual aggregation in the initial aggregate. + val af = patchAggregateFunctionChildren( + e.aggregateFunction, + regularGroupId, + regularAggChildAttrMap) + val a = Alias(e.copy(aggregateFunction = af), e.toString)() + + // Get the result of the first aggregate in the last aggregate. + val b = AggregateExpression2( + aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)), + mode = Complete, + isDistinct = false) + + // Some aggregate functions (COUNT) have the special property that they can return a + // non-null result without any input. We need to make sure we return a result in this case. + val c = af.defaultResult match { + case Some(lit) => Coalesce(Seq(b, lit)) + case None => b + } + + (e, a, c) + } + + // Construct the regular aggregate input projection only if we need one. + val regularAggProjection = if (regularAggExprs.nonEmpty) { + Seq(a.groupingExpressions ++ + distinctAggChildren.map(nullify) ++ + Seq(regularGroupId) ++ + regularAggChildren) + } else { + Seq.empty[Seq[Expression]] + } + + // Construct the distinct aggregate input projections. + val regularAggNulls = regularAggChildren.map(nullify) + val distinctAggProjections = distinctAggOperatorMap.map { + case (projection, _) => + a.groupingExpressions ++ + projection ++ + regularAggNulls + } + + // Construct the expand operator. + val expand = Expand( + regularAggProjection ++ distinctAggProjections, + groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq, + a.child) + + // Construct the first aggregate operator. This de-duplicates the all the children of + // distinct operators, and applies the regular aggregate operators. + val firstAggregateGroupBy = groupByAttrs ++ distinctAggChildAttrs :+ gid + val firstAggregate = Aggregate( + firstAggregateGroupBy, + firstAggregateGroupBy ++ regularAggOperatorMap.map(_._2), + expand) + + // Construct the second aggregate + val transformations: Map[Expression, Expression] = + (distinctAggOperatorMap.flatMap(_._2) ++ + regularAggOperatorMap.map(e => (e._1, e._3))).toMap + + val patchedAggExpressions = a.aggregateExpressions.map { e => + e.transformDown { + case e: Expression => + // The same GROUP BY clauses can have different forms (different names for instance) in + // the groupBy and aggregate expressions of an aggregate. This makes a map lookup + // tricky. So we do a linear search for a semantically equal group by expression. + groupByMap + .find(ge => e.semanticEquals(ge._1)) + .map(_._2) + .getOrElse(transformations.getOrElse(e, e)) + }.asInstanceOf[NamedExpression] + } + Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate) + } else { + a + } + } + + private def nullify(e: Expression) = Literal.create(null, e.dataType) + + private def expressionAttributePair(e: Expression) = + // We are creating a new reference here instead of reusing the attribute in case of a + // NamedExpression. This is done to prevent collisions between distinct and regular aggregate + // children, in this case attribute reuse causes the input of the regular aggregate to bound to + // the (nulled out) input of the distinct aggregate. + e -> new AttributeReference(e.prettyName, e.dataType, true)() +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index a2fab258fcac3..5c5b3d1ccd3cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -133,6 +133,12 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp */ def supportsPartial: Boolean = true + /** + * Result of the aggregate function when the input is empty. This is currently only used for the + * proper rewriting of distinct aggregate functions. + */ + def defaultResult: Option[Literal] = None + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 338c5193cb7a2..d222dfa33ad8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -200,9 +200,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { */ object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case a @ Aggregate(_, _, e @ Expand(_, groupByExprs, _, child)) - if (child.outputSet -- AttributeSet(groupByExprs) -- a.references).nonEmpty => - a.copy(child = e.copy(child = prunedChild(child, AttributeSet(groupByExprs) ++ a.references))) + case a @ Aggregate(_, _, e @ Expand(_, _, child)) + if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty => + a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references))) // Eliminate attributes that are not needed to calculate the specified aggregates. case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 4cb67aacf33ee..fb963e2f8f7e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -235,33 +235,17 @@ case class Window( projectList ++ windowExpressions.map(_.toAttribute) } -/** - * Apply the all of the GroupExpressions to every input row, hence we will get - * multiple output rows for a input row. - * @param bitmasks The bitmask set represents the grouping sets - * @param groupByExprs The grouping by expressions - * @param child Child operator - */ -case class Expand( - bitmasks: Seq[Int], - groupByExprs: Seq[Expression], - gid: Attribute, - child: LogicalPlan) extends UnaryNode { - override def statistics: Statistics = { - val sizeInBytes = child.statistics.sizeInBytes * projections.length - Statistics(sizeInBytes = sizeInBytes) - } - - val projections: Seq[Seq[Expression]] = expand() - +private[sql] object Expand { /** - * Extract attribute set according to the grouping id + * Extract attribute set according to the grouping id. + * * @param bitmask bitmask to represent the selected of the attribute sequence * @param exprs the attributes in sequence * @return the attributes of non selected specified via bitmask (with the bit set to 1) */ - private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression]) - : OpenHashSet[Expression] = { + private def buildNonSelectExprSet( + bitmask: Int, + exprs: Seq[Expression]): OpenHashSet[Expression] = { val set = new OpenHashSet[Expression](2) var bit = exprs.length - 1 @@ -274,18 +258,28 @@ case class Expand( } /** - * Create an array of Projections for the child projection, and replace the projections' - * expressions which equal GroupBy expressions with Literal(null), if those expressions - * are not set for this grouping set (according to the bit mask). + * Apply the all of the GroupExpressions to every input row, hence we will get + * multiple output rows for a input row. + * + * @param bitmasks The bitmask set represents the grouping sets + * @param groupByExprs The grouping by expressions + * @param gid Attribute of the grouping id + * @param child Child operator */ - private[this] def expand(): Seq[Seq[Expression]] = { - val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]] - - bitmasks.foreach { bitmask => + def apply( + bitmasks: Seq[Int], + groupByExprs: Seq[Expression], + gid: Attribute, + child: LogicalPlan): Expand = { + // Create an array of Projections for the child projection, and replace the projections' + // expressions which equal GroupBy expressions with Literal(null), if those expressions + // are not set for this grouping set (according to the bit mask). + val projections = bitmasks.map { bitmask => // get the non selected grouping attributes according to the bit mask val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, groupByExprs) - val substitution = (child.output :+ gid).map(expr => expr transformDown { + (child.output :+ gid).map(expr => expr transformDown { + // TODO this causes a problem when a column is used both for grouping and aggregation. case x: Expression if nonSelectedGroupExprSet.contains(x) => // if the input attribute in the Invalid Grouping Expression set of for this group // replace it with constant null @@ -294,15 +288,29 @@ case class Expand( // replace the groupingId with concrete value (the bit mask) Literal.create(bitmask, IntegerType) }) - - result += substitution } - - result.toSeq + Expand(projections, child.output :+ gid, child) } +} - override def output: Seq[Attribute] = { - child.output :+ gid +/** + * Apply a number of projections to every input row, hence we will get multiple output rows for + * a input row. + * + * @param projections to apply + * @param output of all projections. + * @param child operator. + */ +case class Expand( + projections: Seq[Seq[Expression]], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + + override def statistics: Statistics = { + // TODO shouldn't we factor in the size of the projection versus the size of the backing child + // row? + val sizeInBytes = child.statistics.sizeInBytes * projections.length + Statistics(sizeInBytes = sizeInBytes) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f4464e0b916f8..dd3bb33c57287 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -420,7 +420,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } case logical.Filter(condition, child) => execution.Filter(condition, planLater(child)) :: Nil - case e @ logical.Expand(_, _, _, child) => + case e @ logical.Expand(_, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil case a @ logical.Aggregate(group, agg, child) => { val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled From 1c80d66e52c0bcc4e5adda78b3d8e5bf55e4f128 Mon Sep 17 00:00:00 2001 From: "navis.ryu" Date: Fri, 6 Nov 2015 17:13:46 -0800 Subject: [PATCH 26/88] [SPARK-11546] Thrift server makes too many logs about result schema SparkExecuteStatementOperation logs result schema for each getNextRowSet() calls which is by default every 1000 rows, overwhelming whole log file. Author: navis.ryu Closes #9514 from navis/SPARK-11546. --- .../SparkExecuteStatementOperation.scala | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 719b03e1c7c71..82fef92dcb73b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -53,6 +53,18 @@ private[hive] class SparkExecuteStatementOperation( private var dataTypes: Array[DataType] = _ private var statementId: String = _ + private lazy val resultSchema: TableSchema = { + if (result == null || result.queryExecution.analyzed.output.size == 0) { + new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) + } else { + logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") + val schema = result.queryExecution.analyzed.output.map { attr => + new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") + } + new TableSchema(schema.asJava) + } + } + def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. hiveContext.sparkContext.clearJobGroup() @@ -120,17 +132,7 @@ private[hive] class SparkExecuteStatementOperation( } } - def getResultSetSchema: TableSchema = { - if (result == null || result.queryExecution.analyzed.output.size == 0) { - new TableSchema(Arrays.asList(new FieldSchema("Result", "string", ""))) - } else { - logInfo(s"Result Schema: ${result.queryExecution.analyzed.output}") - val schema = result.queryExecution.analyzed.output.map { attr => - new FieldSchema(attr.name, HiveMetastoreTypes.toMetastoreType(attr.dataType), "") - } - new TableSchema(schema.asJava) - } - } + def getResultSetSchema: TableSchema = resultSchema override def run(): Unit = { setState(OperationState.PENDING) From 105732dcc6b651b9779f4a5773a759c5b4fbd21d Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 6 Nov 2015 17:22:30 -0800 Subject: [PATCH 27/88] [HOTFIX] Fix python tests after #9527 #9527 missed updating the python tests. Author: Michael Armbrust Closes #9533 from marmbrus/hotfixTextValue. --- python/pyspark/sql/readwriter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 97bd90c4db829..927f4077424dc 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -203,7 +203,7 @@ def text(self, path): >>> df = sqlContext.read.text('python/test_support/sql/text-test.txt') >>> df.collect() - [Row(text=u'hello'), Row(text=u'this')] + [Row(value=u'hello'), Row(value=u'this')] """ return self._df(self._jreader.text(path)) From 30b706b7b36482921ec04145a0121ca147984fa8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 6 Nov 2015 18:17:34 -0800 Subject: [PATCH 28/88] [SPARK-11389][CORE] Add support for off-heap memory to MemoryManager In order to lay the groundwork for proper off-heap memory support in SQL / Tungsten, we need to extend our MemoryManager to perform bookkeeping for off-heap memory. ## User-facing changes This PR introduces a new configuration, `spark.memory.offHeapSize` (name subject to change), which specifies the absolute amount of off-heap memory that Spark and Spark SQL can use. If Tungsten is configured to use off-heap execution memory for allocating data pages, then all data page allocations must fit within this size limit. ## Internals changes This PR contains a lot of internal refactoring of the MemoryManager. The key change at the heart of this patch is the introduction of a `MemoryPool` class (name subject to change) to manage the bookkeeping for a particular category of memory (storage, on-heap execution, and off-heap execution). These MemoryPools are not fixed-size; they can be dynamically grown and shrunk according to the MemoryManager's policies. In StaticMemoryManager, these pools have fixed sizes, proportional to the legacy `[storage|shuffle].memoryFraction`. In the new UnifiedMemoryManager, the sizes of these pools are dynamically adjusted according to its policies. There are two subclasses of `MemoryPool`: `StorageMemoryPool` manages storage memory and `ExecutionMemoryPool` manages execution memory. The MemoryManager creates two execution pools, one for on-heap memory and one for off-heap. Instances of `ExecutionMemoryPool` manage the logic for fair sharing of their pooled memory across running tasks (in other words, the ShuffleMemoryManager-like logic has been moved out of MemoryManager and pushed into these ExecutionMemoryPool instances). I think that this design is substantially easier to understand and reason about than the previous design, where most of these responsibilities were handled by MemoryManager and its subclasses. To see this, take at look at how simple the logic in `UnifiedMemoryManager` has become: it's now very easy to see when memory is dynamically shifted between storage and execution. ## TODOs - [x] Fix handful of test failures in the MemoryManagerSuites. - [x] Fix remaining TODO comments in code. - [ ] Document new configuration. - [x] Fix commented-out tests / asserts: - [x] UnifiedMemoryManagerSuite. - [x] Write tests that exercise the new off-heap memory management policies. Author: Josh Rosen Closes #9344 from JoshRosen/offheap-memory-accounting. --- .../apache/spark/memory/MemoryConsumer.java | 7 +- .../org/apache/spark/memory/MemoryMode.java | 26 ++ .../spark/memory/TaskMemoryManager.java | 72 +++-- .../scala/org/apache/spark/SparkEnv.scala | 2 +- .../spark/memory/ExecutionMemoryPool.scala | 153 +++++++++++ .../apache/spark/memory/MemoryManager.scala | 246 ++++++------------ .../org/apache/spark/memory/MemoryPool.scala | 71 +++++ .../spark/memory/StaticMemoryManager.scala | 75 +----- .../spark/memory/StorageMemoryPool.scala | 138 ++++++++++ .../spark/memory/UnifiedMemoryManager.scala | 138 +++++----- .../org/apache/spark/memory/package.scala | 75 ++++++ .../spark/util/collection/Spillable.scala | 8 +- .../spark/memory/TaskMemoryManagerSuite.java | 8 +- .../spark/memory/TestMemoryConsumer.java | 10 +- .../sort/UnsafeShuffleWriterSuite.java | 2 +- .../map/AbstractBytesToBytesMapSuite.java | 4 +- .../spark/memory/MemoryManagerSuite.scala | 104 +++++--- .../memory/StaticMemoryManagerSuite.scala | 39 +-- .../spark/memory/TestMemoryManager.scala | 20 +- .../memory/UnifiedMemoryManagerSuite.scala | 93 +++---- .../spark/storage/BlockManagerSuite.scala | 2 +- 21 files changed, 828 insertions(+), 465 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/memory/MemoryMode.java create mode 100644 core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala create mode 100644 core/src/main/scala/org/apache/spark/memory/MemoryPool.scala create mode 100644 core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala create mode 100644 core/src/main/scala/org/apache/spark/memory/package.scala diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 8fbdb72832adf..36138cc9a297c 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -17,15 +17,15 @@ package org.apache.spark.memory; - import java.io.IOException; import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; - /** * An memory consumer of TaskMemoryManager, which support spilling. + * + * Note: this only supports allocation / spilling of Tungsten memory. */ public abstract class MemoryConsumer { @@ -36,7 +36,6 @@ public abstract class MemoryConsumer { protected MemoryConsumer(TaskMemoryManager taskMemoryManager, long pageSize) { this.taskMemoryManager = taskMemoryManager; this.pageSize = pageSize; - this.used = 0; } protected MemoryConsumer(TaskMemoryManager taskMemoryManager) { @@ -67,6 +66,8 @@ public void spill() throws IOException { * * Note: In order to avoid possible deadlock, should not call acquireMemory() from spill(). * + * Note: today, this only frees Tungsten-managed pages. + * * @param size the amount of memory should be released * @param trigger the MemoryConsumer that trigger this spilling * @return the amount of released memory in bytes diff --git a/core/src/main/java/org/apache/spark/memory/MemoryMode.java b/core/src/main/java/org/apache/spark/memory/MemoryMode.java new file mode 100644 index 0000000000000..3a5e72d8aaec0 --- /dev/null +++ b/core/src/main/java/org/apache/spark/memory/MemoryMode.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory; + +import org.apache.spark.annotation.Private; + +@Private +public enum MemoryMode { + ON_HEAP, + OFF_HEAP +} diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 6440f9c0f30de..5f743b28857b4 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -103,10 +103,10 @@ public class TaskMemoryManager { * without doing any masking or lookups. Since this branching should be well-predicted by the JIT, * this extra layer of indirection / abstraction hopefully shouldn't be too expensive. */ - private final boolean inHeap; + final MemoryMode tungstenMemoryMode; /** - * The size of memory granted to each consumer. + * Tracks spillable memory consumers. */ @GuardedBy("this") private final HashSet consumers; @@ -115,7 +115,7 @@ public class TaskMemoryManager { * Construct a new TaskMemoryManager. */ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { - this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap(); + this.tungstenMemoryMode = memoryManager.tungstenMemoryMode(); this.memoryManager = memoryManager; this.taskAttemptId = taskAttemptId; this.consumers = new HashSet<>(); @@ -127,12 +127,19 @@ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { * * @return number of bytes successfully granted (<= N). */ - public long acquireExecutionMemory(long required, MemoryConsumer consumer) { + public long acquireExecutionMemory( + long required, + MemoryMode mode, + MemoryConsumer consumer) { assert(required >= 0); + // If we are allocating Tungsten pages off-heap and receive a request to allocate on-heap + // memory here, then it may not make sense to spill since that would only end up freeing + // off-heap memory. This is subject to change, though, so it may be risky to make this + // optimization now in case we forget to undo it late when making changes. synchronized (this) { - long got = memoryManager.acquireExecutionMemory(required, taskAttemptId); + long got = memoryManager.acquireExecutionMemory(required, taskAttemptId, mode); - // try to release memory from other consumers first, then we can reduce the frequency of + // Try to release memory from other consumers first, then we can reduce the frequency of // spilling, avoid to have too many spilled files. if (got < required) { // Call spill() on other consumers to release memory @@ -140,10 +147,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { if (c != consumer && c.getUsed() > 0) { try { long released = c.spill(required - got, consumer); - if (released > 0) { - logger.info("Task {} released {} from {} for {}", taskAttemptId, + if (released > 0 && mode == tungstenMemoryMode) { + logger.debug("Task {} released {} from {} for {}", taskAttemptId, Utils.bytesToString(released), c, consumer); - got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); if (got >= required) { break; } @@ -161,10 +168,10 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { if (got < required && consumer != null) { try { long released = consumer.spill(required - got, consumer); - if (released > 0) { - logger.info("Task {} released {} from itself ({})", taskAttemptId, + if (released > 0 && mode == tungstenMemoryMode) { + logger.debug("Task {} released {} from itself ({})", taskAttemptId, Utils.bytesToString(released), consumer); - got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId); + got += memoryManager.acquireExecutionMemory(required - got, taskAttemptId, mode); } } catch (IOException e) { logger.error("error while calling spill() on " + consumer, e); @@ -184,9 +191,9 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { /** * Release N bytes of execution memory for a MemoryConsumer. */ - public void releaseExecutionMemory(long size, MemoryConsumer consumer) { + public void releaseExecutionMemory(long size, MemoryMode mode, MemoryConsumer consumer) { logger.debug("Task {} release {} from {}", taskAttemptId, Utils.bytesToString(size), consumer); - memoryManager.releaseExecutionMemory(size, taskAttemptId); + memoryManager.releaseExecutionMemory(size, taskAttemptId, mode); } /** @@ -195,11 +202,19 @@ public void releaseExecutionMemory(long size, MemoryConsumer consumer) { public void showMemoryUsage() { logger.info("Memory used in task " + taskAttemptId); synchronized (this) { + long memoryAccountedForByConsumers = 0; for (MemoryConsumer c: consumers) { - if (c.getUsed() > 0) { - logger.info("Acquired by " + c + ": " + Utils.bytesToString(c.getUsed())); + long totalMemUsage = c.getUsed(); + memoryAccountedForByConsumers += totalMemUsage; + if (totalMemUsage > 0) { + logger.info("Acquired by " + c + ": " + Utils.bytesToString(totalMemUsage)); } } + long memoryNotAccountedFor = + memoryManager.getExecutionMemoryUsageForTask(taskAttemptId) - memoryAccountedForByConsumers; + logger.info( + "{} bytes of memory were used by task {} but are not associated with specific consumers", + memoryNotAccountedFor, taskAttemptId); } } @@ -214,7 +229,8 @@ public long pageSizeBytes() { * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is * intended for allocating large blocks of Tungsten memory that will be shared between operators. * - * Returns `null` if there was not enough memory to allocate the page. + * Returns `null` if there was not enough memory to allocate the page. May return a page that + * contains fewer bytes than requested, so callers should verify the size of returned pages. */ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { if (size > MAXIMUM_PAGE_SIZE_BYTES) { @@ -222,7 +238,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes"); } - long acquired = acquireExecutionMemory(size, consumer); + long acquired = acquireExecutionMemory(size, tungstenMemoryMode, consumer); if (acquired <= 0) { return null; } @@ -231,7 +247,7 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { synchronized (this) { pageNumber = allocatedPages.nextClearBit(0); if (pageNumber >= PAGE_TABLE_SIZE) { - releaseExecutionMemory(acquired, consumer); + releaseExecutionMemory(acquired, tungstenMemoryMode, consumer); throw new IllegalStateException( "Have already allocated a maximum of " + PAGE_TABLE_SIZE + " pages"); } @@ -262,7 +278,7 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) { } long pageSize = page.size(); memoryManager.tungstenMemoryAllocator().free(page); - releaseExecutionMemory(pageSize, consumer); + releaseExecutionMemory(pageSize, tungstenMemoryMode, consumer); } /** @@ -276,7 +292,7 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) { * @return an encoded page address. */ public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) { - if (!inHeap) { + if (tungstenMemoryMode == MemoryMode.OFF_HEAP) { // In off-heap mode, an offset is an absolute address that may require a full 64 bits to // encode. Due to our page size limitation, though, we can convert this into an offset that's // relative to the page's base offset; this relative offset will fit in 51 bits. @@ -305,7 +321,7 @@ private static long decodeOffset(long pagePlusOffsetAddress) { * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)} */ public Object getPage(long pagePlusOffsetAddress) { - if (inHeap) { + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { final int pageNumber = decodePageNumber(pagePlusOffsetAddress); assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE); final MemoryBlock page = pageTable[pageNumber]; @@ -323,7 +339,7 @@ public Object getPage(long pagePlusOffsetAddress) { */ public long getOffsetInPage(long pagePlusOffsetAddress) { final long offsetInPage = decodeOffset(pagePlusOffsetAddress); - if (inHeap) { + if (tungstenMemoryMode == MemoryMode.ON_HEAP) { return offsetInPage; } else { // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we @@ -351,11 +367,19 @@ public long cleanUpAllAllocatedMemory() { } consumers.clear(); } + + for (MemoryBlock page : pageTable) { + if (page != null) { + memoryManager.tungstenMemoryAllocator().free(page); + } + } + Arrays.fill(pageTable, null); + return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); } /** - * Returns the memory consumption, in bytes, for the current task + * Returns the memory consumption, in bytes, for the current task. */ public long getMemoryConsumptionForThisTask() { return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId); diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 23ae9360f6a22..4474a83bedbdb 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -341,7 +341,7 @@ object SparkEnv extends Logging { if (useLegacyMemoryManager) { new StaticMemoryManager(conf, numUsableCores) } else { - new UnifiedMemoryManager(conf, numUsableCores) + UnifiedMemoryManager(conf, numUsableCores) } val blockTransferService = new NettyBlockTransferService(conf, securityManager, numUsableCores) diff --git a/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala new file mode 100644 index 0000000000000..7825bae425877 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/ExecutionMemoryPool.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable + +import org.apache.spark.Logging + +/** + * Implements policies and bookkeeping for sharing a adjustable-sized pool of memory between tasks. + * + * Tries to ensure that each task gets a reasonable share of memory, instead of some task ramping up + * to a large amount first and then causing others to spill to disk repeatedly. + * + * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory + * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the + * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever this + * set changes. This is all done by synchronizing access to mutable state and using wait() and + * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across + * tasks was performed by the ShuffleMemoryManager. + * + * @param lock a [[MemoryManager]] instance to synchronize on + * @param poolName a human-readable name for this pool, for use in log messages + */ +class ExecutionMemoryPool( + lock: Object, + poolName: String + ) extends MemoryPool(lock) with Logging { + + /** + * Map from taskAttemptId -> memory consumption in bytes + */ + @GuardedBy("lock") + private val memoryForTask = new mutable.HashMap[Long, Long]() + + override def memoryUsed: Long = lock.synchronized { + memoryForTask.values.sum + } + + /** + * Returns the memory consumption, in bytes, for the given task. + */ + def getMemoryUsageForTask(taskAttemptId: Long): Long = lock.synchronized { + memoryForTask.getOrElse(taskAttemptId, 0L) + } + + /** + * Try to acquire up to `numBytes` of memory for the given task and return the number of bytes + * obtained, or 0 if none can be allocated. + * + * This call may block until there is enough free memory in some situations, to make sure each + * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of + * active tasks) before it is forced to spill. This can happen if the number of tasks increase + * but an older task had a lot of memory already. + * + * @return the number of bytes granted to the task. + */ + def acquireMemory(numBytes: Long, taskAttemptId: Long): Long = lock.synchronized { + assert(numBytes > 0, s"invalid number of bytes requested: $numBytes") + + // Add this task to the taskMemory map just so we can keep an accurate count of the number + // of active tasks, to let other tasks ramp down their memory in calls to `acquireMemory` + if (!memoryForTask.contains(taskAttemptId)) { + memoryForTask(taskAttemptId) = 0L + // This will later cause waiting tasks to wake up and check numTasks again + lock.notifyAll() + } + + // Keep looping until we're either sure that we don't want to grant this request (because this + // task would have more than 1 / numActiveTasks of the memory) or we have enough free + // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). + // TODO: simplify this to limit each task to its own slot + while (true) { + val numActiveTasks = memoryForTask.keys.size + val curMem = memoryForTask(taskAttemptId) + + // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; + // don't let it be negative + val maxToGrant = + math.min(numBytes, math.max(0, (poolSize / numActiveTasks) - curMem)) + // Only give it as much memory as is free, which might be none if it reached 1 / numTasks + val toGrant = math.min(maxToGrant, memoryFree) + + if (curMem < poolSize / (2 * numActiveTasks)) { + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (memoryFree >= math.min(maxToGrant, poolSize / (2 * numActiveTasks) - curMem)) { + memoryForTask(taskAttemptId) += toGrant + return toGrant + } else { + logInfo( + s"TID $taskAttemptId waiting for at least 1/2N of $poolName pool to be free") + lock.wait() + } + } else { + memoryForTask(taskAttemptId) += toGrant + return toGrant + } + } + 0L // Never reached + } + + /** + * Release `numBytes` of memory acquired by the given task. + */ + def releaseMemory(numBytes: Long, taskAttemptId: Long): Unit = lock.synchronized { + val curMem = memoryForTask.getOrElse(taskAttemptId, 0L) + var memoryToFree = if (curMem < numBytes) { + logWarning( + s"Internal error: release called on $numBytes bytes but task only has $curMem bytes " + + s"of memory from the $poolName pool") + curMem + } else { + numBytes + } + if (memoryForTask.contains(taskAttemptId)) { + memoryForTask(taskAttemptId) -= memoryToFree + if (memoryForTask(taskAttemptId) <= 0) { + memoryForTask.remove(taskAttemptId) + } + } + lock.notifyAll() // Notify waiters in acquireMemory() that memory has been freed + } + + /** + * Release all memory for the given task and mark it as inactive (e.g. when a task ends). + * @return the number of bytes freed. + */ + def releaseAllMemoryForTask(taskAttemptId: Long): Long = lock.synchronized { + val numBytesToFree = getMemoryUsageForTask(taskAttemptId) + releaseMemory(numBytesToFree, taskAttemptId) + numBytesToFree + } + +} diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index b0cf2696a397f..ceb8ea434e1be 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -20,12 +20,8 @@ package org.apache.spark.memory import javax.annotation.concurrent.GuardedBy import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import com.google.common.annotations.VisibleForTesting - -import org.apache.spark.util.Utils -import org.apache.spark.{SparkException, TaskContext, SparkConf, Logging} +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.memory.MemoryAllocator @@ -36,53 +32,40 @@ import org.apache.spark.unsafe.memory.MemoryAllocator * In this context, execution memory refers to that used for computation in shuffles, joins, * sorts and aggregations, while storage memory refers to that used for caching and propagating * internal data across the cluster. There exists one MemoryManager per JVM. - * - * The MemoryManager abstract base class itself implements policies for sharing execution memory - * between tasks; it tries to ensure that each task gets a reasonable share of memory, instead of - * some task ramping up to a large amount first and then causing others to spill to disk repeatedly. - * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory - * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the - * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever - * this set changes. This is all done by synchronizing access to mutable state and using wait() and - * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across - * tasks was performed by the ShuffleMemoryManager. */ -private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) extends Logging { +private[spark] abstract class MemoryManager( + conf: SparkConf, + numCores: Int, + storageMemory: Long, + onHeapExecutionMemory: Long) extends Logging { // -- Methods related to memory allocation policies and bookkeeping ------------------------------ - // The memory store used to evict cached blocks - private var _memoryStore: MemoryStore = _ - protected def memoryStore: MemoryStore = { - if (_memoryStore == null) { - throw new IllegalArgumentException("memory store not initialized yet") - } - _memoryStore - } + @GuardedBy("this") + protected val storageMemoryPool = new StorageMemoryPool(this) + @GuardedBy("this") + protected val onHeapExecutionMemoryPool = new ExecutionMemoryPool(this, "on-heap execution") + @GuardedBy("this") + protected val offHeapExecutionMemoryPool = new ExecutionMemoryPool(this, "off-heap execution") - // Amount of execution/storage memory in use, accesses must be synchronized on `this` - @GuardedBy("this") protected var _executionMemoryUsed: Long = 0 - @GuardedBy("this") protected var _storageMemoryUsed: Long = 0 - // Map from taskAttemptId -> memory consumption in bytes - @GuardedBy("this") private val executionMemoryForTask = new mutable.HashMap[Long, Long]() - - /** - * Set the [[MemoryStore]] used by this manager to evict cached blocks. - * This must be set after construction due to initialization ordering constraints. - */ - final def setMemoryStore(store: MemoryStore): Unit = { - _memoryStore = store - } + storageMemoryPool.incrementPoolSize(storageMemory) + onHeapExecutionMemoryPool.incrementPoolSize(onHeapExecutionMemory) + offHeapExecutionMemoryPool.incrementPoolSize(conf.getSizeAsBytes("spark.memory.offHeapSize", 0)) /** - * Total available memory for execution, in bytes. + * Total available memory for storage, in bytes. This amount can vary over time, depending on + * the MemoryManager implementation. + * In this model, this is equivalent to the amount of memory not occupied by execution. */ - def maxExecutionMemory: Long + def maxStorageMemory: Long /** - * Total available memory for storage, in bytes. + * Set the [[MemoryStore]] used by this manager to evict cached blocks. + * This must be set after construction due to initialization ordering constraints. */ - def maxStorageMemory: Long + final def setMemoryStore(store: MemoryStore): Unit = synchronized { + storageMemoryPool.setMemoryStore(store) + } // TODO: avoid passing evicted blocks around to simplify method signatures (SPARK-10985) @@ -94,7 +77,9 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte def acquireStorageMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) + } /** * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. @@ -109,103 +94,25 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte def acquireUnrollMemory( blockId: BlockId, numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - acquireStorageMemory(blockId, numBytes, evictedBlocks) - } - - /** - * Acquire N bytes of memory for execution, evicting cached blocks if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. - * @return number of bytes successfully granted (<= N). - */ - @VisibleForTesting - private[memory] def doAcquireExecutionMemory( - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean /** - * Try to acquire up to `numBytes` of execution memory for the current task and return the number - * of bytes obtained, or 0 if none can be allocated. + * Try to acquire up to `numBytes` of execution memory for the current task and return the + * number of bytes obtained, or 0 if none can be allocated. * * This call may block until there is enough free memory in some situations, to make sure each * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of * active tasks) before it is forced to spill. This can happen if the number of tasks increase * but an older task had a lot of memory already. - * - * Subclasses should override `doAcquireExecutionMemory` in order to customize the policies - * that control global sharing of memory between execution and storage. */ private[memory] - final def acquireExecutionMemory(numBytes: Long, taskAttemptId: Long): Long = synchronized { - assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) - - // Add this task to the taskMemory map just so we can keep an accurate count of the number - // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire - if (!executionMemoryForTask.contains(taskAttemptId)) { - executionMemoryForTask(taskAttemptId) = 0L - // This will later cause waiting tasks to wake up and check numTasks again - notifyAll() - } - - // Once the cross-task memory allocation policy has decided to grant more memory to a task, - // this method is called in order to actually obtain that execution memory, potentially - // triggering eviction of storage memory: - def acquire(toGrant: Long): Long = synchronized { - val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - val acquired = doAcquireExecutionMemory(toGrant, evictedBlocks) - // Register evicted blocks, if any, with the active task metrics - Option(TaskContext.get()).foreach { tc => - val metrics = tc.taskMetrics() - val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) - metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) - } - executionMemoryForTask(taskAttemptId) += acquired - acquired - } - - // Keep looping until we're either sure that we don't want to grant this request (because this - // task would have more than 1 / numActiveTasks of the memory) or we have enough free - // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). - // TODO: simplify this to limit each task to its own slot - while (true) { - val numActiveTasks = executionMemoryForTask.keys.size - val curMem = executionMemoryForTask(taskAttemptId) - val freeMemory = maxExecutionMemory - executionMemoryForTask.values.sum - - // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; - // don't let it be negative - val maxToGrant = - math.min(numBytes, math.max(0, (maxExecutionMemory / numActiveTasks) - curMem)) - // Only give it as much memory as is free, which might be none if it reached 1 / numTasks - val toGrant = math.min(maxToGrant, freeMemory) - - if (curMem < maxExecutionMemory / (2 * numActiveTasks)) { - // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; - // if we can't give it this much now, wait for other tasks to free up memory - // (this happens if older tasks allocated lots of memory before N grew) - if ( - freeMemory >= math.min(maxToGrant, maxExecutionMemory / (2 * numActiveTasks) - curMem)) { - return acquire(toGrant) - } else { - logInfo( - s"TID $taskAttemptId waiting for at least 1/2N of execution memory pool to be free") - wait() - } - } else { - return acquire(toGrant) - } - } - 0L // Never reached - } - - @VisibleForTesting - private[memory] def releaseExecutionMemory(numBytes: Long): Unit = synchronized { - if (numBytes > _executionMemoryUsed) { - logWarning(s"Attempted to release $numBytes bytes of execution " + - s"memory when we only have ${_executionMemoryUsed} bytes") - _executionMemoryUsed = 0 - } else { - _executionMemoryUsed -= numBytes + def acquireExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Long = synchronized { + memoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) } } @@ -213,24 +120,14 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte * Release numBytes of execution memory belonging to the given task. */ private[memory] - final def releaseExecutionMemory(numBytes: Long, taskAttemptId: Long): Unit = synchronized { - val curMem = executionMemoryForTask.getOrElse(taskAttemptId, 0L) - if (curMem < numBytes) { - if (Utils.isTesting) { - throw new SparkException( - s"Internal error: release called on $numBytes bytes but task only has $curMem") - } else { - logWarning(s"Internal error: release called on $numBytes bytes but task only has $curMem") - } - } - if (executionMemoryForTask.contains(taskAttemptId)) { - executionMemoryForTask(taskAttemptId) -= numBytes - if (executionMemoryForTask(taskAttemptId) <= 0) { - executionMemoryForTask.remove(taskAttemptId) - } - releaseExecutionMemory(numBytes) + def releaseExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Unit = synchronized { + memoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.releaseMemory(numBytes, taskAttemptId) + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.releaseMemory(numBytes, taskAttemptId) } - notifyAll() // Notify waiters in acquireExecutionMemory() that memory has been freed } /** @@ -238,35 +135,28 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte * @return the number of bytes freed. */ private[memory] def releaseAllExecutionMemoryForTask(taskAttemptId: Long): Long = synchronized { - val numBytesToFree = getExecutionMemoryUsageForTask(taskAttemptId) - releaseExecutionMemory(numBytesToFree, taskAttemptId) - numBytesToFree + onHeapExecutionMemoryPool.releaseAllMemoryForTask(taskAttemptId) + + offHeapExecutionMemoryPool.releaseAllMemoryForTask(taskAttemptId) } /** * Release N bytes of storage memory. */ def releaseStorageMemory(numBytes: Long): Unit = synchronized { - if (numBytes > _storageMemoryUsed) { - logWarning(s"Attempted to release $numBytes bytes of storage " + - s"memory when we only have ${_storageMemoryUsed} bytes") - _storageMemoryUsed = 0 - } else { - _storageMemoryUsed -= numBytes - } + storageMemoryPool.releaseMemory(numBytes) } /** * Release all storage memory acquired. */ - def releaseAllStorageMemory(): Unit = synchronized { - _storageMemoryUsed = 0 + final def releaseAllStorageMemory(): Unit = synchronized { + storageMemoryPool.releaseAllMemory() } /** * Release N bytes of unroll memory. */ - def releaseUnrollMemory(numBytes: Long): Unit = synchronized { + final def releaseUnrollMemory(numBytes: Long): Unit = synchronized { releaseStorageMemory(numBytes) } @@ -274,25 +164,34 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte * Execution memory currently in use, in bytes. */ final def executionMemoryUsed: Long = synchronized { - _executionMemoryUsed + onHeapExecutionMemoryPool.memoryUsed + offHeapExecutionMemoryPool.memoryUsed } /** * Storage memory currently in use, in bytes. */ final def storageMemoryUsed: Long = synchronized { - _storageMemoryUsed + storageMemoryPool.memoryUsed } /** * Returns the execution memory consumption, in bytes, for the given task. */ private[memory] def getExecutionMemoryUsageForTask(taskAttemptId: Long): Long = synchronized { - executionMemoryForTask.getOrElse(taskAttemptId, 0L) + onHeapExecutionMemoryPool.getMemoryUsageForTask(taskAttemptId) + + offHeapExecutionMemoryPool.getMemoryUsageForTask(taskAttemptId) } // -- Fields related to Tungsten managed memory ------------------------------------------------- + /** + * Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using + * sun.misc.Unsafe. + */ + final val tungstenMemoryMode: MemoryMode = { + if (conf.getBoolean("spark.unsafe.offHeap", false)) MemoryMode.OFF_HEAP else MemoryMode.ON_HEAP + } + /** * The default page size, in bytes. * @@ -306,21 +205,22 @@ private[spark] abstract class MemoryManager(conf: SparkConf, numCores: Int) exte val cores = if (numCores > 0) numCores else Runtime.getRuntime.availableProcessors() // Because of rounding to next power of 2, we may have safetyFactor as 8 in worst case val safetyFactor = 16 - val size = ByteArrayMethods.nextPowerOf2(maxExecutionMemory / cores / safetyFactor) + val maxTungstenMemory: Long = tungstenMemoryMode match { + case MemoryMode.ON_HEAP => onHeapExecutionMemoryPool.poolSize + case MemoryMode.OFF_HEAP => offHeapExecutionMemoryPool.poolSize + } + val size = ByteArrayMethods.nextPowerOf2(maxTungstenMemory / cores / safetyFactor) val default = math.min(maxPageSize, math.max(minPageSize, size)) conf.getSizeAsBytes("spark.buffer.pageSize", default) } - /** - * Tracks whether Tungsten memory will be allocated on the JVM heap or off-heap using - * sun.misc.Unsafe. - */ - final val tungstenMemoryIsAllocatedInHeap: Boolean = - !conf.getBoolean("spark.unsafe.offHeap", false) - /** * Allocates memory for use by Unsafe/Tungsten code. */ - private[memory] final val tungstenMemoryAllocator: MemoryAllocator = - if (tungstenMemoryIsAllocatedInHeap) MemoryAllocator.HEAP else MemoryAllocator.UNSAFE + private[memory] final val tungstenMemoryAllocator: MemoryAllocator = { + tungstenMemoryMode match { + case MemoryMode.ON_HEAP => MemoryAllocator.HEAP + case MemoryMode.OFF_HEAP => MemoryAllocator.UNSAFE + } + } } diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala new file mode 100644 index 0000000000000..bfeec47e3892e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/MemoryPool.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +/** + * Manages bookkeeping for an adjustable-sized region of memory. This class is internal to + * the [[MemoryManager]]. See subclasses for more details. + * + * @param lock a [[MemoryManager]] instance, used for synchronization. We purposely erase the type + * to `Object` to avoid programming errors, since this object should only be used for + * synchronization purposes. + */ +abstract class MemoryPool(lock: Object) { + + @GuardedBy("lock") + private[this] var _poolSize: Long = 0 + + /** + * Returns the current size of the pool, in bytes. + */ + final def poolSize: Long = lock.synchronized { + _poolSize + } + + /** + * Returns the amount of free memory in the pool, in bytes. + */ + final def memoryFree: Long = lock.synchronized { + _poolSize - memoryUsed + } + + /** + * Expands the pool by `delta` bytes. + */ + final def incrementPoolSize(delta: Long): Unit = lock.synchronized { + require(delta >= 0) + _poolSize += delta + } + + /** + * Shrinks the pool by `delta` bytes. + */ + final def decrementPoolSize(delta: Long): Unit = lock.synchronized { + require(delta >= 0) + require(delta <= _poolSize) + require(_poolSize - delta >= memoryUsed) + _poolSize -= delta + } + + /** + * Returns the amount of used memory in this pool (in bytes). + */ + def memoryUsed: Long +} diff --git a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala index 9c2c2e90a2282..12a094306861f 100644 --- a/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/StaticMemoryManager.scala @@ -22,7 +22,6 @@ import scala.collection.mutable import org.apache.spark.SparkConf import org.apache.spark.storage.{BlockId, BlockStatus} - /** * A [[MemoryManager]] that statically partitions the heap space into disjoint regions. * @@ -32,10 +31,14 @@ import org.apache.spark.storage.{BlockId, BlockStatus} */ private[spark] class StaticMemoryManager( conf: SparkConf, - override val maxExecutionMemory: Long, + maxOnHeapExecutionMemory: Long, override val maxStorageMemory: Long, numCores: Int) - extends MemoryManager(conf, numCores) { + extends MemoryManager( + conf, + numCores, + maxStorageMemory, + maxOnHeapExecutionMemory) { def this(conf: SparkConf, numCores: Int) { this( @@ -50,76 +53,15 @@ private[spark] class StaticMemoryManager( (maxStorageMemory * conf.getDouble("spark.storage.unrollFraction", 0.2)).toLong } - /** - * Acquire N bytes of memory for execution. - * @return number of bytes successfully granted (<= N). - */ - override def doAcquireExecutionMemory( - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { - assert(numBytes >= 0) - assert(_executionMemoryUsed <= maxExecutionMemory) - val bytesToGrant = math.min(numBytes, maxExecutionMemory - _executionMemoryUsed) - _executionMemoryUsed += bytesToGrant - bytesToGrant - } - - /** - * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. - * @return whether all N bytes were successfully granted. - */ - override def acquireStorageMemory( - blockId: BlockId, - numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - acquireStorageMemory(blockId, numBytes, numBytes, evictedBlocks) - } - - /** - * Acquire N bytes of memory to unroll the given block, evicting existing ones if necessary. - * - * This evicts at most M bytes worth of existing blocks, where M is a fraction of the storage - * space specified by `spark.storage.unrollFraction`. Blocks evicted in the process, if any, - * are added to `evictedBlocks`. - * - * @return whether all N bytes were successfully granted. - */ override def acquireUnrollMemory( blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - val currentUnrollMemory = memoryStore.currentUnrollMemory + val currentUnrollMemory = storageMemoryPool.memoryStore.currentUnrollMemory val maxNumBytesToFree = math.max(0, maxMemoryToEvictForUnroll - currentUnrollMemory) val numBytesToFree = math.min(numBytes, maxNumBytesToFree) - acquireStorageMemory(blockId, numBytes, numBytesToFree, evictedBlocks) + storageMemoryPool.acquireMemory(blockId, numBytes, numBytesToFree, evictedBlocks) } - - /** - * Acquire N bytes of storage memory for the given block, evicting existing ones if necessary. - * - * @param blockId the ID of the block we are acquiring storage memory for - * @param numBytesToAcquire the size of this block - * @param numBytesToFree the size of space to be freed through evicting blocks - * @param evictedBlocks a holder for blocks evicted in the process - * @return whether all N bytes were successfully granted. - */ - private def acquireStorageMemory( - blockId: BlockId, - numBytesToAcquire: Long, - numBytesToFree: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { - assert(numBytesToAcquire >= 0) - assert(numBytesToFree >= 0) - memoryStore.ensureFreeSpace(blockId, numBytesToFree, evictedBlocks) - assert(_storageMemoryUsed <= maxStorageMemory) - val enoughMemory = _storageMemoryUsed + numBytesToAcquire <= maxStorageMemory - if (enoughMemory) { - _storageMemoryUsed += numBytesToAcquire - } - enoughMemory - } - } @@ -135,7 +77,6 @@ private[spark] object StaticMemoryManager { (systemMaxMemory * memoryFraction * safetyFraction).toLong } - /** * Return the total amount of memory available for the execution region, in bytes. */ diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala new file mode 100644 index 0000000000000..6a322eabf81ed --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.memory + +import javax.annotation.concurrent.GuardedBy + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{TaskContext, Logging} +import org.apache.spark.storage.{MemoryStore, BlockStatus, BlockId} + +/** + * Performs bookkeeping for managing an adjustable-size pool of memory that is used for storage + * (caching). + * + * @param lock a [[MemoryManager]] instance to synchronize on + */ +class StorageMemoryPool(lock: Object) extends MemoryPool(lock) with Logging { + + @GuardedBy("lock") + private[this] var _memoryUsed: Long = 0L + + override def memoryUsed: Long = lock.synchronized { + _memoryUsed + } + + private var _memoryStore: MemoryStore = _ + def memoryStore: MemoryStore = { + if (_memoryStore == null) { + throw new IllegalStateException("memory store not initialized yet") + } + _memoryStore + } + + /** + * Set the [[MemoryStore]] used by this manager to evict cached blocks. + * This must be set after construction due to initialization ordering constraints. + */ + final def setMemoryStore(store: MemoryStore): Unit = { + _memoryStore = store + } + + /** + * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. + * Blocks evicted in the process, if any, are added to `evictedBlocks`. + * @return whether all N bytes were successfully granted. + */ + def acquireMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = lock.synchronized { + acquireMemory(blockId, numBytes, numBytes, evictedBlocks) + } + + /** + * Acquire N bytes of storage memory for the given block, evicting existing ones if necessary. + * + * @param blockId the ID of the block we are acquiring storage memory for + * @param numBytesToAcquire the size of this block + * @param numBytesToFree the size of space to be freed through evicting blocks + * @return whether all N bytes were successfully granted. + */ + def acquireMemory( + blockId: BlockId, + numBytesToAcquire: Long, + numBytesToFree: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = lock.synchronized { + assert(numBytesToAcquire >= 0) + assert(numBytesToFree >= 0) + assert(memoryUsed <= poolSize) + memoryStore.ensureFreeSpace(blockId, numBytesToFree, evictedBlocks) + // Register evicted blocks, if any, with the active task metrics + Option(TaskContext.get()).foreach { tc => + val metrics = tc.taskMetrics() + val lastUpdatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) + metrics.updatedBlocks = Some(lastUpdatedBlocks ++ evictedBlocks.toSeq) + } + // NOTE: If the memory store evicts blocks, then those evictions will synchronously call + // back into this StorageMemoryPool in order to free. Therefore, these variables should have + // been updated. + val enoughMemory = numBytesToAcquire <= memoryFree + if (enoughMemory) { + _memoryUsed += numBytesToAcquire + } + enoughMemory + } + + def releaseMemory(size: Long): Unit = lock.synchronized { + if (size > _memoryUsed) { + logWarning(s"Attempted to release $size bytes of storage " + + s"memory when we only have ${_memoryUsed} bytes") + _memoryUsed = 0 + } else { + _memoryUsed -= size + } + } + + def releaseAllMemory(): Unit = lock.synchronized { + _memoryUsed = 0 + } + + /** + * Try to shrink the size of this storage memory pool by `spaceToFree` bytes. Return the number + * of bytes removed from the pool's capacity. + */ + def shrinkPoolToFreeSpace(spaceToFree: Long): Long = lock.synchronized { + // First, shrink the pool by reclaiming free memory: + val spaceFreedByReleasingUnusedMemory = Math.min(spaceToFree, memoryFree) + decrementPoolSize(spaceFreedByReleasingUnusedMemory) + if (spaceFreedByReleasingUnusedMemory == spaceToFree) { + spaceFreedByReleasingUnusedMemory + } else { + // If reclaiming free memory did not adequately shrink the pool, begin evicting blocks: + val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + memoryStore.ensureFreeSpace(spaceToFree - spaceFreedByReleasingUnusedMemory, evictedBlocks) + val spaceFreedByEviction = evictedBlocks.map(_._2.memSize).sum + _memoryUsed -= spaceFreedByEviction + decrementPoolSize(spaceFreedByEviction) + spaceFreedByReleasingUnusedMemory + spaceFreedByEviction + } + } +} diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index a3093030a0f93..8be5b05419094 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -22,7 +22,6 @@ import scala.collection.mutable import org.apache.spark.SparkConf import org.apache.spark.storage.{BlockStatus, BlockId} - /** * A [[MemoryManager]] that enforces a soft boundary between execution and storage such that * either side can borrow memory from the other. @@ -41,98 +40,105 @@ import org.apache.spark.storage.{BlockStatus, BlockId} * The implication is that attempts to cache blocks may fail if execution has already eaten * up most of the storage space, in which case the new blocks will be evicted immediately * according to their respective storage levels. + * + * @param storageRegionSize Size of the storage region, in bytes. + * This region is not statically reserved; execution can borrow from + * it if necessary. Cached blocks can be evicted only if actual + * storage memory usage exceeds this region. */ -private[spark] class UnifiedMemoryManager( +private[spark] class UnifiedMemoryManager private[memory] ( conf: SparkConf, maxMemory: Long, + private val storageRegionSize: Long, numCores: Int) - extends MemoryManager(conf, numCores) { - - def this(conf: SparkConf, numCores: Int) { - this(conf, UnifiedMemoryManager.getMaxMemory(conf), numCores) - } - - /** - * Size of the storage region, in bytes. - * - * This region is not statically reserved; execution can borrow from it if necessary. - * Cached blocks can be evicted only if actual storage memory usage exceeds this region. - */ - private val storageRegionSize: Long = { - (maxMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong - } - - /** - * Total amount of memory, in bytes, not currently occupied by either execution or storage. - */ - private def totalFreeMemory: Long = synchronized { - assert(_executionMemoryUsed <= maxMemory) - assert(_storageMemoryUsed <= maxMemory) - assert(_executionMemoryUsed + _storageMemoryUsed <= maxMemory) - maxMemory - _executionMemoryUsed - _storageMemoryUsed - } + extends MemoryManager( + conf, + numCores, + storageRegionSize, + maxMemory - storageRegionSize) { - /** - * Total available memory for execution, in bytes. - * In this model, this is equivalent to the amount of memory not occupied by storage. - */ - override def maxExecutionMemory: Long = synchronized { - maxMemory - _storageMemoryUsed - } + // We always maintain this invariant: + assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) - /** - * Total available memory for storage, in bytes. - * In this model, this is equivalent to the amount of memory not occupied by execution. - */ override def maxStorageMemory: Long = synchronized { - maxMemory - _executionMemoryUsed + maxMemory - onHeapExecutionMemoryPool.memoryUsed } /** - * Acquire N bytes of memory for execution, evicting cached blocks if necessary. + * Try to acquire up to `numBytes` of execution memory for the current task and return the + * number of bytes obtained, or 0 if none can be allocated. * - * This method evicts blocks only up to the amount of memory borrowed by storage. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. - * @return number of bytes successfully granted (<= N). + * This call may block until there is enough free memory in some situations, to make sure each + * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of + * active tasks) before it is forced to spill. This can happen if the number of tasks increase + * but an older task had a lot of memory already. */ - private[memory] override def doAcquireExecutionMemory( + override private[memory] def acquireExecutionMemory( numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { + taskAttemptId: Long, + memoryMode: MemoryMode): Long = synchronized { + assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) assert(numBytes >= 0) - val memoryBorrowedByStorage = math.max(0, _storageMemoryUsed - storageRegionSize) - // If there is not enough free memory AND storage has borrowed some execution memory, - // then evict as much memory borrowed by storage as needed to grant this request - val shouldEvictStorage = totalFreeMemory < numBytes && memoryBorrowedByStorage > 0 - if (shouldEvictStorage) { - val spaceToEnsure = math.min(numBytes, memoryBorrowedByStorage) - memoryStore.ensureFreeSpace(spaceToEnsure, evictedBlocks) + memoryMode match { + case MemoryMode.ON_HEAP => + if (numBytes > onHeapExecutionMemoryPool.memoryFree) { + val extraMemoryNeeded = numBytes - onHeapExecutionMemoryPool.memoryFree + // There is not enough free memory in the execution pool, so try to reclaim memory from + // storage. We can reclaim any free memory from the storage pool. If the storage pool + // has grown to become larger than `storageRegionSize`, we can evict blocks and reclaim + // the memory that storage has borrowed from execution. + val memoryReclaimableFromStorage = + math.max(storageMemoryPool.memoryFree, storageMemoryPool.poolSize - storageRegionSize) + if (memoryReclaimableFromStorage > 0) { + // Only reclaim as much space as is necessary and available: + val spaceReclaimed = storageMemoryPool.shrinkPoolToFreeSpace( + math.min(extraMemoryNeeded, memoryReclaimableFromStorage)) + onHeapExecutionMemoryPool.incrementPoolSize(spaceReclaimed) + } + } + onHeapExecutionMemoryPool.acquireMemory(numBytes, taskAttemptId) + case MemoryMode.OFF_HEAP => + // For now, we only support on-heap caching of data, so we do not need to interact with + // the storage pool when allocating off-heap memory. This will change in the future, though. + super.acquireExecutionMemory(numBytes, taskAttemptId, memoryMode) } - val bytesToGrant = math.min(numBytes, totalFreeMemory) - _executionMemoryUsed += bytesToGrant - bytesToGrant } - /** - * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. - * Blocks evicted in the process, if any, are added to `evictedBlocks`. - * @return whether all N bytes were successfully granted. - */ override def acquireStorageMemory( blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + assert(onHeapExecutionMemoryPool.poolSize + storageMemoryPool.poolSize == maxMemory) assert(numBytes >= 0) - memoryStore.ensureFreeSpace(blockId, numBytes, evictedBlocks) - val enoughMemory = totalFreeMemory >= numBytes - if (enoughMemory) { - _storageMemoryUsed += numBytes + if (numBytes > storageMemoryPool.memoryFree) { + // There is not enough free memory in the storage pool, so try to borrow free memory from + // the execution pool. + val memoryBorrowedFromExecution = Math.min(onHeapExecutionMemoryPool.memoryFree, numBytes) + onHeapExecutionMemoryPool.decrementPoolSize(memoryBorrowedFromExecution) + storageMemoryPool.incrementPoolSize(memoryBorrowedFromExecution) } - enoughMemory + storageMemoryPool.acquireMemory(blockId, numBytes, evictedBlocks) } + override def acquireUnrollMemory( + blockId: BlockId, + numBytes: Long, + evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = synchronized { + acquireStorageMemory(blockId, numBytes, evictedBlocks) + } } -private object UnifiedMemoryManager { +object UnifiedMemoryManager { + + def apply(conf: SparkConf, numCores: Int): UnifiedMemoryManager = { + val maxMemory = getMaxMemory(conf) + new UnifiedMemoryManager( + conf, + maxMemory = maxMemory, + storageRegionSize = + (maxMemory * conf.getDouble("spark.memory.storageFraction", 0.5)).toLong, + numCores = numCores) + } /** * Return the total amount of memory shared between execution and storage, in bytes. diff --git a/core/src/main/scala/org/apache/spark/memory/package.scala b/core/src/main/scala/org/apache/spark/memory/package.scala new file mode 100644 index 0000000000000..564e30d2ffd66 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/memory/package.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * This package implements Spark's memory management system. This system consists of two main + * components, a JVM-wide memory manager and a per-task manager: + * + * - [[org.apache.spark.memory.MemoryManager]] manages Spark's overall memory usage within a JVM. + * This component implements the policies for dividing the available memory across tasks and for + * allocating memory between storage (memory used caching and data transfer) and execution (memory + * used by computations, such as shuffles, joins, sorts, and aggregations). + * - [[org.apache.spark.memory.TaskMemoryManager]] manages the memory allocated by individual tasks. + * Tasks interact with TaskMemoryManager and never directly interact with the JVM-wide + * MemoryManager. + * + * Internally, each of these components have additional abstractions for memory bookkeeping: + * + * - [[org.apache.spark.memory.MemoryConsumer]]s are clients of the TaskMemoryManager and + * correspond to individual operators and data structures within a task. The TaskMemoryManager + * receives memory allocation requests from MemoryConsumers and issues callbacks to consumers + * in order to trigger spilling when running low on memory. + * - [[org.apache.spark.memory.MemoryPool]]s are a bookkeeping abstraction used by the + * MemoryManager to track the division of memory between storage and execution. + * + * Diagrammatically: + * + * {{{ + * +-------------+ + * | MemConsumer |----+ +------------------------+ + * +-------------+ | +-------------------+ | MemoryManager | + * +--->| TaskMemoryManager |----+ | | + * +-------------+ | +-------------------+ | | +------------------+ | + * | MemConsumer |----+ | | | StorageMemPool | | + * +-------------+ +-------------------+ | | +------------------+ | + * | TaskMemoryManager |----+ | | + * +-------------------+ | | +------------------+ | + * +---->| |OnHeapExecMemPool | | + * * | | +------------------+ | + * * | | | + * +-------------+ * | | +------------------+ | + * | MemConsumer |----+ | | |OffHeapExecMemPool| | + * +-------------+ | +-------------------+ | | +------------------+ | + * +--->| TaskMemoryManager |----+ | | + * +-------------------+ +------------------------+ + * }}} + * + * + * There are two implementations of [[org.apache.spark.memory.MemoryManager]] which vary in how + * they handle the sizing of their memory pools: + * + * - [[org.apache.spark.memory.UnifiedMemoryManager]], the default in Spark 1.6+, enforces soft + * boundaries between storage and execution memory, allowing requests for memory in one region + * to be fulfilled by borrowing memory from the other. + * - [[org.apache.spark.memory.StaticMemoryManager]] enforces hard boundaries between storage + * and execution memory by statically partitioning Spark's memory and preventing storage and + * execution from borrowing memory from each other. This mode is retained only for legacy + * compatibility purposes. + */ +package object memory diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 9e002621a6909..3a48af82b1dae 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -17,7 +17,7 @@ package org.apache.spark.util.collection -import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.{Logging, SparkEnv} /** @@ -78,7 +78,8 @@ private[spark] trait Spillable[C] extends Logging { if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = taskMemoryManager.acquireExecutionMemory(amountToRequest, null) + val granted = + taskMemoryManager.acquireExecutionMemory(amountToRequest, MemoryMode.ON_HEAP, null) myMemoryThreshold += granted // If we were granted too little memory to grow further (either tryToAcquire returned 0, // or we already had more memory than myMemoryThreshold), spill the current collection @@ -107,7 +108,8 @@ private[spark] trait Spillable[C] extends Logging { */ def releaseMemory(): Unit = { // The amount we requested does not include the initial memory tracking threshold - taskMemoryManager.releaseExecutionMemory(myMemoryThreshold - initialMemoryThreshold, null) + taskMemoryManager.releaseExecutionMemory( + myMemoryThreshold - initialMemoryThreshold, MemoryMode.ON_HEAP, null) myMemoryThreshold = initialMemoryThreshold } diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index c731317395612..711eed0193bc0 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -28,8 +28,14 @@ public class TaskMemoryManagerSuite { @Test public void leakedPageMemoryIsDetected() { final TaskMemoryManager manager = new TaskMemoryManager( - new TestMemoryManager(new SparkConf().set("spark.unsafe.offHeap", "false")), 0); + new StaticMemoryManager( + new SparkConf().set("spark.unsafe.offHeap", "false"), + Long.MAX_VALUE, + Long.MAX_VALUE, + 1), + 0); manager.allocatePage(4096, null); // leak memory + Assert.assertEquals(4096, manager.getMemoryConsumptionForThisTask()); Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory()); } diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java index 8ae3642738509..e6e16fff80401 100644 --- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -32,13 +32,19 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { } void use(long size) { - long got = taskMemoryManager.acquireExecutionMemory(size, this); + long got = taskMemoryManager.acquireExecutionMemory( + size, + taskMemoryManager.tungstenMemoryMode, + this); used += got; } void free(long size) { used -= size; - taskMemoryManager.releaseExecutionMemory(size, this); + taskMemoryManager.releaseExecutionMemory( + size, + taskMemoryManager.tungstenMemoryMode, + this); } } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 4763395d7d401..0e0eca515afc1 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -423,7 +423,7 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exce memoryManager.limit(UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE * 16); final UnsafeShuffleWriter writer = createWriter(false); final ArrayList> dataToWrite = new ArrayList<>(); - for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) { + for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE + 1; i++) { dataToWrite.add(new Tuple2(i, i)); } writer.write(dataToWrite.iterator()); diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 92bd45e5fa241..3bca790f30870 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -83,7 +83,9 @@ public OutputStream apply(OutputStream stream) { public void setup() { memoryManager = new TestMemoryManager( - new SparkConf().set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator())); + new SparkConf() + .set("spark.unsafe.offHeap", "" + useOffHeapMemoryAllocator()) + .set("spark.memory.offHeapSize", "256mb")); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "unsafe-test"); diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 4a9479cf490fb..f55d435fa33a6 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.memory import java.util.concurrent.atomic.AtomicLong +import scala.collection.mutable import scala.concurrent.duration.Duration import scala.concurrent.{Await, ExecutionContext, Future} @@ -29,7 +30,7 @@ import org.mockito.stubbing.Answer import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite -import org.apache.spark.storage.MemoryStore +import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, StorageLevel} /** @@ -78,7 +79,12 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { require(args(numBytesPos).isInstanceOf[Long], s"bad test: expected ensureFreeSpace " + s"argument at index $numBytesPos to be a Long: ${args.mkString(", ")}") val numBytes = args(numBytesPos).asInstanceOf[Long] - mockEnsureFreeSpace(mm, numBytes) + val success = mockEnsureFreeSpace(mm, numBytes) + if (success) { + args.last.asInstanceOf[mutable.Buffer[(BlockId, BlockStatus)]].append( + (null, BlockStatus(StorageLevel.MEMORY_ONLY, numBytes, 0L, 0L))) + } + success } } } @@ -132,93 +138,95 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { } /** - * Create a MemoryManager with the specified execution memory limit and no storage memory. + * Create a MemoryManager with the specified execution memory limits and no storage memory. */ - protected def createMemoryManager(maxExecutionMemory: Long): MemoryManager + protected def createMemoryManager( + maxOnHeapExecutionMemory: Long, + maxOffHeapExecutionMemory: Long = 0L): MemoryManager // -- Tests of sharing of execution memory between tasks ---------------------------------------- // Prior to Spark 1.6, these tests were part of ShuffleMemoryManagerSuite. implicit val ec = ExecutionContext.global - test("single task requesting execution memory") { + test("single task requesting on-heap execution memory") { val manager = createMemoryManager(1000L) val taskMemoryManager = new TaskMemoryManager(manager, 0) - assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 100L) - assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L) - assert(taskMemoryManager.acquireExecutionMemory(400L, null) === 400L) - assert(taskMemoryManager.acquireExecutionMemory(200L, null) === 100L) - assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L) - assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(400L, MemoryMode.ON_HEAP, null) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(400L, MemoryMode.ON_HEAP, null) === 400L) + assert(taskMemoryManager.acquireExecutionMemory(200L, MemoryMode.ON_HEAP, null) === 100L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) - taskMemoryManager.releaseExecutionMemory(500L, null) - assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 300L) - assert(taskMemoryManager.acquireExecutionMemory(300L, null) === 200L) + taskMemoryManager.releaseExecutionMemory(500L, MemoryMode.ON_HEAP, null) + assert(taskMemoryManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) === 300L) + assert(taskMemoryManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) === 200L) taskMemoryManager.cleanUpAllAllocatedMemory() - assert(taskMemoryManager.acquireExecutionMemory(1000L, null) === 1000L) - assert(taskMemoryManager.acquireExecutionMemory(100L, null) === 0L) + assert(taskMemoryManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) === 1000L) + assert(taskMemoryManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) === 0L) } - test("two tasks requesting full execution memory") { + test("two tasks requesting full on-heap execution memory") { val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) val futureTimeout: Duration = 20.seconds // Have both tasks request 500 bytes, then wait until both requests have been granted: - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, null) } - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 500L) assert(Await.result(t2Result1, futureTimeout) === 500L) // Have both tasks each request 500 bytes more; both should immediately return 0 as they are // both now at 1 / N - val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) } - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result2, 200.millis) === 0L) assert(Await.result(t2Result2, 200.millis) === 0L) } - test("two tasks cannot grow past 1 / N of execution memory") { + test("two tasks cannot grow past 1 / N of on-heap execution memory") { val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) val futureTimeout: Duration = 20.seconds // Have both tasks request 250 bytes, then wait until both requests have been granted: - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, null) } - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 250L) assert(Await.result(t2Result1, futureTimeout) === 250L) // Have both tasks each request 500 bytes more. // We should only grant 250 bytes to each of them on this second request - val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, null) } - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result2, futureTimeout) === 250L) assert(Await.result(t2Result2, futureTimeout) === 250L) } - test("tasks can block to get at least 1 / 2N of execution memory") { + test("tasks can block to get at least 1 / 2N of on-heap execution memory") { val memoryManager = createMemoryManager(1000L) val t1MemManager = new TaskMemoryManager(memoryManager, 1) val t2MemManager = new TaskMemoryManager(memoryManager, 2) val futureTimeout: Duration = 20.seconds // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 1000L) - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult // to make sure the other thread blocks for some time otherwise. Thread.sleep(300) - t1MemManager.releaseExecutionMemory(250L, null) + t1MemManager.releaseExecutionMemory(250L, MemoryMode.ON_HEAP, null) // The memory freed from t1 should now be granted to t2. assert(Await.result(t2Result1, futureTimeout) === 250L) // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory. - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) } assert(Await.result(t2Result2, 200.millis) === 0L) } @@ -229,18 +237,18 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { val futureTimeout: Duration = 20.seconds // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 1000L) - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult // to make sure the other thread blocks for some time otherwise. Thread.sleep(300) // t1 releases all of its memory, so t2 should be able to grab all of the memory t1MemManager.cleanUpAllAllocatedMemory() assert(Await.result(t2Result1, futureTimeout) === 500L) - val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t2Result2, futureTimeout) === 500L) - val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, null) } + val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } assert(Await.result(t2Result3, 200.millis) === 0L) } @@ -251,15 +259,35 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite { val t2MemManager = new TaskMemoryManager(memoryManager, 2) val futureTimeout: Duration = 20.seconds - val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, null) } + val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result1, futureTimeout) === 700L) - val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, null) } + val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } assert(Await.result(t2Result1, futureTimeout) === 300L) - val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, null) } + val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } assert(Await.result(t1Result2, 200.millis) === 0L) } + + test("off-heap execution allocations cannot exceed limit") { + val memoryManager = createMemoryManager( + maxOnHeapExecutionMemory = 0L, + maxOffHeapExecutionMemory = 1000L) + + val tMemManager = new TaskMemoryManager(memoryManager, 1) + val result1 = Future { tMemManager.acquireExecutionMemory(1000L, MemoryMode.OFF_HEAP, null) } + assert(Await.result(result1, 200.millis) === 1000L) + assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) + + val result2 = Future { tMemManager.acquireExecutionMemory(300L, MemoryMode.OFF_HEAP, null) } + assert(Await.result(result2, 200.millis) === 0L) + + assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) + tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null) + assert(tMemManager.getMemoryConsumptionForThisTask === 500L) + tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null) + assert(tMemManager.getMemoryConsumptionForThisTask === 0L) + } } private object MemoryManagerSuite { diff --git a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala index 885c450d6d4f5..54cb28c389c2f 100644 --- a/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/StaticMemoryManagerSuite.scala @@ -24,7 +24,6 @@ import org.mockito.Mockito.when import org.apache.spark.SparkConf import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} - class StaticMemoryManagerSuite extends MemoryManagerSuite { private val conf = new SparkConf().set("spark.storage.unrollFraction", "0.4") private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] @@ -36,38 +35,47 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { maxExecutionMem: Long, maxStorageMem: Long): (StaticMemoryManager, MemoryStore) = { val mm = new StaticMemoryManager( - conf, maxExecutionMemory = maxExecutionMem, maxStorageMemory = maxStorageMem, numCores = 1) + conf, + maxOnHeapExecutionMemory = maxExecutionMem, + maxStorageMemory = maxStorageMem, + numCores = 1) val ms = makeMemoryStore(mm) (mm, ms) } - override protected def createMemoryManager(maxMemory: Long): MemoryManager = { + override protected def createMemoryManager( + maxOnHeapExecutionMemory: Long, + maxOffHeapExecutionMemory: Long): StaticMemoryManager = { new StaticMemoryManager( - conf, - maxExecutionMemory = maxMemory, + conf.clone + .set("spark.memory.fraction", "1") + .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) + .set("spark.memory.offHeapSize", maxOffHeapExecutionMemory.toString), + maxOnHeapExecutionMemory = maxOnHeapExecutionMemory, maxStorageMemory = 0, numCores = 1) } test("basic execution memory") { val maxExecutionMem = 1000L + val taskAttemptId = 0L val (mm, _) = makeThings(maxExecutionMem, Long.MaxValue) assert(mm.executionMemoryUsed === 0L) - assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L) + assert(mm.acquireExecutionMemory(10L, taskAttemptId, MemoryMode.ON_HEAP) === 10L) assert(mm.executionMemoryUsed === 10L) - assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) // Acquire up to the max - assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L) + assert(mm.acquireExecutionMemory(1000L, taskAttemptId, MemoryMode.ON_HEAP) === 890L) assert(mm.executionMemoryUsed === maxExecutionMem) - assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, MemoryMode.ON_HEAP) === 0L) assert(mm.executionMemoryUsed === maxExecutionMem) - mm.releaseExecutionMemory(800L) + mm.releaseExecutionMemory(800L, taskAttemptId, MemoryMode.ON_HEAP) assert(mm.executionMemoryUsed === 200L) // Acquire after release - assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, MemoryMode.ON_HEAP) === 1L) assert(mm.executionMemoryUsed === 201L) // Release beyond what was acquired - mm.releaseExecutionMemory(maxExecutionMem) + mm.releaseExecutionMemory(maxExecutionMem, taskAttemptId, MemoryMode.ON_HEAP) assert(mm.executionMemoryUsed === 0L) } @@ -113,13 +121,14 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { test("execution and storage isolation") { val maxExecutionMem = 200L val maxStorageMem = 1000L + val taskAttemptId = 0L val dummyBlock = TestBlockId("ain't nobody love like you do") val (mm, ms) = makeThings(maxExecutionMem, maxStorageMem) // Only execution memory should increase - assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 100L) - assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(1000L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) assert(mm.storageMemoryUsed === 0L) assert(mm.executionMemoryUsed === 200L) // Only storage memory should increase @@ -128,7 +137,7 @@ class StaticMemoryManagerSuite extends MemoryManagerSuite { assert(mm.storageMemoryUsed === 50L) assert(mm.executionMemoryUsed === 200L) // Only execution memory should be released - mm.releaseExecutionMemory(133L) + mm.releaseExecutionMemory(133L, taskAttemptId, MemoryMode.ON_HEAP) assert(mm.storageMemoryUsed === 50L) assert(mm.executionMemoryUsed === 67L) // Only storage memory should be released diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala index 77e43554ee27c..0706a6e45de8f 100644 --- a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala +++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala @@ -22,19 +22,20 @@ import scala.collection.mutable import org.apache.spark.SparkConf import org.apache.spark.storage.{BlockStatus, BlockId} -class TestMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1) { - private[memory] override def doAcquireExecutionMemory( +class TestMemoryManager(conf: SparkConf) + extends MemoryManager(conf, numCores = 1, Long.MaxValue, Long.MaxValue) { + + override private[memory] def acquireExecutionMemory( numBytes: Long, - evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Long = synchronized { + taskAttemptId: Long, + memoryMode: MemoryMode): Long = { if (oomOnce) { oomOnce = false 0 } else if (available >= numBytes) { - _executionMemoryUsed += numBytes // To suppress warnings when freeing unallocated memory available -= numBytes numBytes } else { - _executionMemoryUsed += available val grant = available available = 0 grant @@ -48,12 +49,13 @@ class TestMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = blockId: BlockId, numBytes: Long, evictedBlocks: mutable.Buffer[(BlockId, BlockStatus)]): Boolean = true - override def releaseExecutionMemory(numBytes: Long): Unit = { + override def releaseStorageMemory(numBytes: Long): Unit = {} + override private[memory] def releaseExecutionMemory( + numBytes: Long, + taskAttemptId: Long, + memoryMode: MemoryMode): Unit = { available += numBytes - _executionMemoryUsed -= numBytes } - override def releaseStorageMemory(numBytes: Long): Unit = {} - override def maxExecutionMemory: Long = Long.MaxValue override def maxStorageMemory: Long = Long.MaxValue private var oomOnce = false diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala index 0c97f2bd89651..8cebe81c3bfff 100644 --- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala @@ -24,57 +24,52 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.SparkConf import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore, TestBlockId} - class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTester { - private val conf = new SparkConf().set("spark.memory.storageFraction", "0.5") private val dummyBlock = TestBlockId("--") private val evictedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] + private val storageFraction: Double = 0.5 + /** * Make a [[UnifiedMemoryManager]] and a [[MemoryStore]] with limited class dependencies. */ private def makeThings(maxMemory: Long): (UnifiedMemoryManager, MemoryStore) = { - val mm = new UnifiedMemoryManager(conf, maxMemory, numCores = 1) + val mm = createMemoryManager(maxMemory) val ms = makeMemoryStore(mm) (mm, ms) } - override protected def createMemoryManager(maxMemory: Long): MemoryManager = { - new UnifiedMemoryManager(conf, maxMemory, numCores = 1) - } - - private def getStorageRegionSize(mm: UnifiedMemoryManager): Long = { - mm invokePrivate PrivateMethod[Long]('storageRegionSize)() - } - - test("storage region size") { - val maxMemory = 1000L - val (mm, _) = makeThings(maxMemory) - val storageFraction = conf.get("spark.memory.storageFraction").toDouble - val expectedStorageRegionSize = maxMemory * storageFraction - val actualStorageRegionSize = getStorageRegionSize(mm) - assert(expectedStorageRegionSize === actualStorageRegionSize) + override protected def createMemoryManager( + maxOnHeapExecutionMemory: Long, + maxOffHeapExecutionMemory: Long): UnifiedMemoryManager = { + val conf = new SparkConf() + .set("spark.memory.fraction", "1") + .set("spark.testing.memory", maxOnHeapExecutionMemory.toString) + .set("spark.memory.offHeapSize", maxOffHeapExecutionMemory.toString) + .set("spark.memory.storageFraction", storageFraction.toString) + UnifiedMemoryManager(conf, numCores = 1) } test("basic execution memory") { val maxMemory = 1000L + val taskAttemptId = 0L val (mm, _) = makeThings(maxMemory) assert(mm.executionMemoryUsed === 0L) - assert(mm.doAcquireExecutionMemory(10L, evictedBlocks) === 10L) + assert(mm.acquireExecutionMemory(10L, taskAttemptId, MemoryMode.ON_HEAP) === 10L) assert(mm.executionMemoryUsed === 10L) - assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) // Acquire up to the max - assert(mm.doAcquireExecutionMemory(1000L, evictedBlocks) === 890L) + assert(mm.acquireExecutionMemory(1000L, taskAttemptId, MemoryMode.ON_HEAP) === 890L) assert(mm.executionMemoryUsed === maxMemory) - assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 0L) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, MemoryMode.ON_HEAP) === 0L) assert(mm.executionMemoryUsed === maxMemory) - mm.releaseExecutionMemory(800L) + mm.releaseExecutionMemory(800L, taskAttemptId, MemoryMode.ON_HEAP) assert(mm.executionMemoryUsed === 200L) // Acquire after release - assert(mm.doAcquireExecutionMemory(1L, evictedBlocks) === 1L) + assert(mm.acquireExecutionMemory(1L, taskAttemptId, MemoryMode.ON_HEAP) === 1L) assert(mm.executionMemoryUsed === 201L) // Release beyond what was acquired - mm.releaseExecutionMemory(maxMemory) + mm.releaseExecutionMemory(maxMemory, taskAttemptId, MemoryMode.ON_HEAP) assert(mm.executionMemoryUsed === 0L) } @@ -118,44 +113,34 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes test("execution evicts storage") { val maxMemory = 1000L + val taskAttemptId = 0L val (mm, ms) = makeThings(maxMemory) - // First, ensure the test classes are set up as expected - val expectedStorageRegionSize = 500L - val expectedExecutionRegionSize = 500L - val storageRegionSize = getStorageRegionSize(mm) - val executionRegionSize = maxMemory - expectedStorageRegionSize - require(storageRegionSize === expectedStorageRegionSize, - "bad test: storage region size is unexpected") - require(executionRegionSize === expectedExecutionRegionSize, - "bad test: storage region size is unexpected") // Acquire enough storage memory to exceed the storage region assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) assertEnsureFreeSpaceCalled(ms, 750L) assert(mm.executionMemoryUsed === 0L) assert(mm.storageMemoryUsed === 750L) - require(mm.storageMemoryUsed > storageRegionSize, - s"bad test: storage memory used should exceed the storage region") // Execution needs to request 250 bytes to evict storage memory - assert(mm.doAcquireExecutionMemory(100L, evictedBlocks) === 100L) + assert(mm.acquireExecutionMemory(100L, taskAttemptId, MemoryMode.ON_HEAP) === 100L) assert(mm.executionMemoryUsed === 100L) assert(mm.storageMemoryUsed === 750L) assertEnsureFreeSpaceNotCalled(ms) // Execution wants 200 bytes but only 150 are free, so storage is evicted - assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L) - assertEnsureFreeSpaceCalled(ms, 200L) + assert(mm.acquireExecutionMemory(200L, taskAttemptId, MemoryMode.ON_HEAP) === 200L) + assert(mm.executionMemoryUsed === 300L) + assertEnsureFreeSpaceCalled(ms, 50L) assert(mm.executionMemoryUsed === 300L) mm.releaseAllStorageMemory() - require(mm.executionMemoryUsed < executionRegionSize, - s"bad test: execution memory used should be within the execution region") + require(mm.executionMemoryUsed === 300L) require(mm.storageMemoryUsed === 0, "bad test: all storage memory should have been released") // Acquire some storage memory again, but this time keep it within the storage region assert(mm.acquireStorageMemory(dummyBlock, 400L, evictedBlocks)) assertEnsureFreeSpaceCalled(ms, 400L) - require(mm.storageMemoryUsed < storageRegionSize, - s"bad test: storage memory used should be within the storage region") + assert(mm.storageMemoryUsed === 400L) + assert(mm.executionMemoryUsed === 300L) // Execution cannot evict storage because the latter is within the storage fraction, // so grant only what's remaining without evicting anything, i.e. 1000 - 300 - 400 = 300 - assert(mm.doAcquireExecutionMemory(400L, evictedBlocks) === 300L) + assert(mm.acquireExecutionMemory(400L, taskAttemptId, MemoryMode.ON_HEAP) === 300L) assert(mm.executionMemoryUsed === 600L) assert(mm.storageMemoryUsed === 400L) assertEnsureFreeSpaceNotCalled(ms) @@ -163,23 +148,13 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes test("storage does not evict execution") { val maxMemory = 1000L + val taskAttemptId = 0L val (mm, ms) = makeThings(maxMemory) - // First, ensure the test classes are set up as expected - val expectedStorageRegionSize = 500L - val expectedExecutionRegionSize = 500L - val storageRegionSize = getStorageRegionSize(mm) - val executionRegionSize = maxMemory - expectedStorageRegionSize - require(storageRegionSize === expectedStorageRegionSize, - "bad test: storage region size is unexpected") - require(executionRegionSize === expectedExecutionRegionSize, - "bad test: storage region size is unexpected") // Acquire enough execution memory to exceed the execution region - assert(mm.doAcquireExecutionMemory(800L, evictedBlocks) === 800L) + assert(mm.acquireExecutionMemory(800L, taskAttemptId, MemoryMode.ON_HEAP) === 800L) assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 0L) assertEnsureFreeSpaceNotCalled(ms) - require(mm.executionMemoryUsed > executionRegionSize, - s"bad test: execution memory used should exceed the execution region") // Storage should not be able to evict execution assert(mm.acquireStorageMemory(dummyBlock, 100L, evictedBlocks)) assert(mm.executionMemoryUsed === 800L) @@ -189,15 +164,13 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes assert(mm.executionMemoryUsed === 800L) assert(mm.storageMemoryUsed === 100L) assertEnsureFreeSpaceCalled(ms, 250L) - mm.releaseExecutionMemory(maxMemory) + mm.releaseExecutionMemory(maxMemory, taskAttemptId, MemoryMode.ON_HEAP) mm.releaseStorageMemory(maxMemory) // Acquire some execution memory again, but this time keep it within the execution region - assert(mm.doAcquireExecutionMemory(200L, evictedBlocks) === 200L) + assert(mm.acquireExecutionMemory(200L, taskAttemptId, MemoryMode.ON_HEAP) === 200L) assert(mm.executionMemoryUsed === 200L) assert(mm.storageMemoryUsed === 0L) assertEnsureFreeSpaceNotCalled(ms) - require(mm.executionMemoryUsed < executionRegionSize, - s"bad test: execution memory used should be within the execution region") // Storage should still not be able to evict execution assert(mm.acquireStorageMemory(dummyBlock, 750L, evictedBlocks)) assert(mm.executionMemoryUsed === 200L) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index d49015afcd594..53991d8a1aede 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -825,7 +825,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val memoryManager = new StaticMemoryManager( conf, - maxExecutionMemory = Long.MaxValue, + maxOnHeapExecutionMemory = Long.MaxValue, maxStorageMemory = 1200, numCores = 1) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, From 7f741905b06ed6d3dfbff6db41a3355dab71aa3c Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sat, 7 Nov 2015 05:35:53 +0100 Subject: [PATCH 29/88] [SPARK-11112] DAG visualization: display RDD callsite screen shot 2015-11-01 at 9 42 33 am mateiz sarutak Author: Andrew Or Closes #9398 from andrewor14/rdd-callsite. --- .../apache/spark/ui/static/spark-dag-viz.css | 4 ++ .../org/apache/spark/storage/RDDInfo.scala | 16 +++++++- .../spark/ui/scope/RDDOperationGraph.scala | 10 +++-- .../org/apache/spark/util/JsonProtocol.scala | 17 ++++++++- .../scala/org/apache/spark/util/Utils.scala | 1 + .../org/apache/spark/ui/UISeleniumSuite.scala | 14 +++---- .../apache/spark/util/JsonProtocolSuite.scala | 37 ++++++++++++++++--- 7 files changed, 79 insertions(+), 20 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css index 3b4ae2ed354b8..9cc5c79f67346 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.css @@ -122,3 +122,7 @@ stroke: #52C366; stroke-width: 2px; } + +.tooltip-inner { + white-space: pre-wrap; +} diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 96062626b5045..3fa209b924170 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.{RDDOperationScope, RDD} -import org.apache.spark.util.Utils +import org.apache.spark.util.{CallSite, Utils} @DeveloperApi class RDDInfo( @@ -28,9 +28,20 @@ class RDDInfo( val numPartitions: Int, var storageLevel: StorageLevel, val parentIds: Seq[Int], + val callSite: CallSite, val scope: Option[RDDOperationScope] = None) extends Ordered[RDDInfo] { + def this( + id: Int, + name: String, + numPartitions: Int, + storageLevel: StorageLevel, + parentIds: Seq[Int], + scope: Option[RDDOperationScope] = None) { + this(id, name, numPartitions, storageLevel, parentIds, CallSite.empty, scope) + } + var numCachedPartitions = 0 var memSize = 0L var diskSize = 0L @@ -56,6 +67,7 @@ private[spark] object RDDInfo { def fromRdd(rdd: RDD[_]): RDDInfo = { val rddName = Option(rdd.name).getOrElse(Utils.getFormattedClassName(rdd)) val parentIds = rdd.dependencies.map(_.rdd.id) - new RDDInfo(rdd.id, rddName, rdd.partitions.length, rdd.getStorageLevel, parentIds, rdd.scope) + new RDDInfo(rdd.id, rddName, rdd.partitions.length, + rdd.getStorageLevel, parentIds, rdd.creationSite, rdd.scope) } } diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index 81f168a447ead..24274562657b3 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.{StringBuilder, ListBuffer} import org.apache.spark.Logging import org.apache.spark.scheduler.StageInfo import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.CallSite /** * A representation of a generic cluster graph used for storing information on RDD operations. @@ -38,7 +39,7 @@ private[ui] case class RDDOperationGraph( rootCluster: RDDOperationCluster) /** A node in an RDDOperationGraph. This represents an RDD. */ -private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean) +private[ui] case class RDDOperationNode(id: Int, name: String, cached: Boolean, callsite: CallSite) /** * A directed edge connecting two nodes in an RDDOperationGraph. @@ -104,8 +105,8 @@ private[ui] object RDDOperationGraph extends Logging { edges ++= rdd.parentIds.map { parentId => RDDOperationEdge(parentId, rdd.id) } // TODO: differentiate between the intention to cache an RDD and whether it's actually cached - val node = nodes.getOrElseUpdate( - rdd.id, RDDOperationNode(rdd.id, rdd.name, rdd.storageLevel != StorageLevel.NONE)) + val node = nodes.getOrElseUpdate(rdd.id, RDDOperationNode( + rdd.id, rdd.name, rdd.storageLevel != StorageLevel.NONE, rdd.callSite)) if (rdd.scope.isEmpty) { // This RDD has no encompassing scope, so we put it directly in the root cluster @@ -177,7 +178,8 @@ private[ui] object RDDOperationGraph extends Logging { /** Return the dot representation of a node in an RDDOperationGraph. */ private def makeDotNode(node: RDDOperationNode): String = { - s"""${node.id} [label="${node.name} [${node.id}]"]""" + val label = s"${node.name} [${node.id}]\n${node.callsite.shortForm}" + s"""${node.id} [label="$label"]""" } /** Update the dot representation of the RDDOperationGraph in cluster to subgraph. */ diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index ee2eb58cf5e2a..c9beeb25e05af 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -398,6 +398,7 @@ private[spark] object JsonProtocol { ("RDD ID" -> rddInfo.id) ~ ("Name" -> rddInfo.name) ~ ("Scope" -> rddInfo.scope.map(_.toJson)) ~ + ("Callsite" -> callsiteToJson(rddInfo.callSite)) ~ ("Parent IDs" -> parentIds) ~ ("Storage Level" -> storageLevel) ~ ("Number of Partitions" -> rddInfo.numPartitions) ~ @@ -407,6 +408,11 @@ private[spark] object JsonProtocol { ("Disk Size" -> rddInfo.diskSize) } + def callsiteToJson(callsite: CallSite): JValue = { + ("Short Form" -> callsite.shortForm) ~ + ("Long Form" -> callsite.longForm) + } + def storageLevelToJson(storageLevel: StorageLevel): JValue = { ("Use Disk" -> storageLevel.useDisk) ~ ("Use Memory" -> storageLevel.useMemory) ~ @@ -851,6 +857,9 @@ private[spark] object JsonProtocol { val scope = Utils.jsonOption(json \ "Scope") .map(_.extract[String]) .map(RDDOperationScope.fromJson) + val callsite = Utils.jsonOption(json \ "Callsite") + .map(callsiteFromJson) + .getOrElse(CallSite.empty) val parentIds = Utils.jsonOption(json \ "Parent IDs") .map { l => l.extract[List[JValue]].map(_.extract[Int]) } .getOrElse(Seq.empty) @@ -863,7 +872,7 @@ private[spark] object JsonProtocol { .getOrElse(json \ "Tachyon Size").extract[Long] val diskSize = (json \ "Disk Size").extract[Long] - val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel, parentIds, scope) + val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel, parentIds, callsite, scope) rddInfo.numCachedPartitions = numCachedPartitions rddInfo.memSize = memSize rddInfo.externalBlockStoreSize = externalBlockStoreSize @@ -871,6 +880,12 @@ private[spark] object JsonProtocol { rddInfo } + def callsiteFromJson(json: JValue): CallSite = { + val shortForm = (json \ "Short Form").extract[String] + val longForm = (json \ "Long Form").extract[String] + CallSite(shortForm, longForm) + } + def storageLevelFromJson(json: JValue): StorageLevel = { val useDisk = (json \ "Use Disk").extract[Boolean] val useMemory = (json \ "Use Memory").extract[Boolean] diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5a976ee839b1e..316c194ff3454 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -57,6 +57,7 @@ private[spark] case class CallSite(shortForm: String, longForm: String) private[spark] object CallSite { val SHORT_FORM = "callSite.short" val LONG_FORM = "callSite.long" + val empty = CallSite("", "") } /** diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 18eec7da9763e..ceecfd665bf87 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -615,29 +615,29 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B assert(stage0.contains("digraph G {\n subgraph clusterstage_0 {\n " + "label="Stage 0";\n subgraph ")) assert(stage0.contains("{\n label="parallelize";\n " + - "0 [label="ParallelCollectionRDD [0]"];\n }")) + "0 [label="ParallelCollectionRDD [0]")) assert(stage0.contains("{\n label="map";\n " + - "1 [label="MapPartitionsRDD [1]"];\n }")) + "1 [label="MapPartitionsRDD [1]")) assert(stage0.contains("{\n label="groupBy";\n " + - "2 [label="MapPartitionsRDD [2]"];\n }")) + "2 [label="MapPartitionsRDD [2]")) val stage1 = Source.fromURL(sc.ui.get.appUIAddress + "/stages/stage/?id=1&attempt=0&expandDagViz=true").mkString assert(stage1.contains("digraph G {\n subgraph clusterstage_1 {\n " + "label="Stage 1";\n subgraph ")) assert(stage1.contains("{\n label="groupBy";\n " + - "3 [label="ShuffledRDD [3]"];\n }")) + "3 [label="ShuffledRDD [3]")) assert(stage1.contains("{\n label="map";\n " + - "4 [label="MapPartitionsRDD [4]"];\n }")) + "4 [label="MapPartitionsRDD [4]")) assert(stage1.contains("{\n label="groupBy";\n " + - "5 [label="MapPartitionsRDD [5]"];\n }")) + "5 [label="MapPartitionsRDD [5]")) val stage2 = Source.fromURL(sc.ui.get.appUIAddress + "/stages/stage/?id=2&attempt=0&expandDagViz=true").mkString assert(stage2.contains("digraph G {\n subgraph clusterstage_2 {\n " + "label="Stage 2";\n subgraph ")) assert(stage2.contains("{\n label="groupBy";\n " + - "6 [label="ShuffledRDD [6]"];\n }")) + "6 [label="ShuffledRDD [6]")) } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 953456c2caa89..3f94ef7041914 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -111,6 +111,7 @@ class JsonProtocolSuite extends SparkFunSuite { test("Dependent Classes") { val logUrlMap = Map("stderr" -> "mystderr", "stdout" -> "mystdout").toMap testRDDInfo(makeRddInfo(2, 3, 4, 5L, 6L)) + testCallsite(CallSite("happy", "birthday")) testStageInfo(makeStageInfo(10, 20, 30, 40L, 50L)) testTaskInfo(makeTaskInfo(999L, 888, 55, 777L, false)) testTaskMetrics(makeTaskMetrics( @@ -163,6 +164,10 @@ class JsonProtocolSuite extends SparkFunSuite { testBlockId(StreamBlockId(1, 2L)) } + /* ============================== * + | Backward compatibility tests | + * ============================== */ + test("ExceptionFailure backward compatibility") { val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, null, None, None) @@ -334,14 +339,17 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(expectedJobEnd, JsonProtocol.jobEndFromJson(oldEndEvent)) } - test("RDDInfo backward compatibility (scope, parent IDs)") { - // Prior to Spark 1.4.0, RDDInfo did not have the "Scope" and "Parent IDs" properties - val rddInfo = new RDDInfo( - 1, "one", 100, StorageLevel.NONE, Seq(1, 6, 8), Some(new RDDOperationScope("fable"))) + test("RDDInfo backward compatibility (scope, parent IDs, callsite)") { + // "Scope" and "Parent IDs" were introduced in Spark 1.4.0 + // "Callsite" was introduced in Spark 1.6.0 + val rddInfo = new RDDInfo(1, "one", 100, StorageLevel.NONE, Seq(1, 6, 8), + CallSite("short", "long"), Some(new RDDOperationScope("fable"))) val oldRddInfoJson = JsonProtocol.rddInfoToJson(rddInfo) .removeField({ _._1 == "Parent IDs"}) .removeField({ _._1 == "Scope"}) - val expectedRddInfo = new RDDInfo(1, "one", 100, StorageLevel.NONE, Seq.empty, scope = None) + .removeField({ _._1 == "Callsite"}) + val expectedRddInfo = new RDDInfo( + 1, "one", 100, StorageLevel.NONE, Seq.empty, CallSite.empty, scope = None) assertEquals(expectedRddInfo, JsonProtocol.rddInfoFromJson(oldRddInfoJson)) } @@ -389,6 +397,11 @@ class JsonProtocolSuite extends SparkFunSuite { assertEquals(info, newInfo) } + private def testCallsite(callsite: CallSite): Unit = { + val newCallsite = JsonProtocol.callsiteFromJson(JsonProtocol.callsiteToJson(callsite)) + assert(callsite === newCallsite) + } + private def testStageInfo(info: StageInfo) { val newInfo = JsonProtocol.stageInfoFromJson(JsonProtocol.stageInfoToJson(info)) assertEquals(info, newInfo) @@ -713,7 +726,8 @@ class JsonProtocolSuite extends SparkFunSuite { } private def makeRddInfo(a: Int, b: Int, c: Int, d: Long, e: Long) = { - val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, Seq(1, 4, 7)) + val r = new RDDInfo(a, "mayor", b, StorageLevel.MEMORY_AND_DISK, + Seq(1, 4, 7), CallSite(a.toString, b.toString)) r.numCachedPartitions = c r.memSize = d r.diskSize = e @@ -856,6 +870,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 101, | "Name": "mayor", + | "Callsite": {"Short Form": "101", "Long Form": "201"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1258,6 +1273,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 1, | "Name": "mayor", + | "Callsite": {"Short Form": "1", "Long Form": "200"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1301,6 +1317,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 2, | "Name": "mayor", + | "Callsite": {"Short Form": "2", "Long Form": "400"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1318,6 +1335,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", + | "Callsite": {"Short Form": "3", "Long Form": "401"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1361,6 +1379,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 3, | "Name": "mayor", + | "Callsite": {"Short Form": "3", "Long Form": "600"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1378,6 +1397,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", + | "Callsite": {"Short Form": "4", "Long Form": "601"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1395,6 +1415,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", + | "Callsite": {"Short Form": "5", "Long Form": "602"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1438,6 +1459,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 4, | "Name": "mayor", + | "Callsite": {"Short Form": "4", "Long Form": "800"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1455,6 +1477,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 5, | "Name": "mayor", + | "Callsite": {"Short Form": "5", "Long Form": "801"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1472,6 +1495,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 6, | "Name": "mayor", + | "Callsite": {"Short Form": "6", "Long Form": "802"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, @@ -1489,6 +1513,7 @@ class JsonProtocolSuite extends SparkFunSuite { | { | "RDD ID": 7, | "Name": "mayor", + | "Callsite": {"Short Form": "7", "Long Form": "803"}, | "Parent IDs": [1, 4, 7], | "Storage Level": { | "Use Disk": true, From 2ff0e79a8647cca5c9c57f613a07e739ac4f677e Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Fri, 6 Nov 2015 22:56:29 -0800 Subject: [PATCH 30/88] [SPARK-8467] [MLLIB] [PYSPARK] Add LDAModel.describeTopics() in Python Could jkbradley and davies review it? - Create a wrapper class: `LDAModelWrapper` for `LDAModel`. Because we can't deal with the return value of`describeTopics` in Scala from pyspark directly. `Array[(Array[Int], Array[Double])]` is too complicated to convert it. - Add `loadLDAModel` in `PythonMLlibAPI`. Since `LDAModel` in Scala is an abstract class and we need to call `load` of `DistributedLDAModel`. [[SPARK-8467] Add LDAModel.describeTopics() in Python - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-8467) Author: Yu ISHIKAWA Closes #8643 from yu-iskw/SPARK-8467-2. --- .../mllib/api/python/LDAModelWrapper.scala | 46 +++++++++++++++++++ .../mllib/api/python/PythonMLLibAPI.scala | 13 +++++- python/pyspark/mllib/clustering.py | 33 +++++++------ 3 files changed, 75 insertions(+), 17 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala new file mode 100644 index 0000000000000..63282eee6e656 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.api.python + +import scala.collection.JavaConverters + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.clustering.LDAModel +import org.apache.spark.mllib.linalg.Matrix + +/** + * Wrapper around LDAModel to provide helper methods in Python + */ +private[python] class LDAModelWrapper(model: LDAModel) { + + def topicsMatrix(): Matrix = model.topicsMatrix + + def vocabSize(): Int = model.vocabSize + + def describeTopics(): Array[Byte] = describeTopics(this.model.vocabSize) + + def describeTopics(maxTermsPerTopic: Int): Array[Byte] = { + val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) => + val jTerms = JavaConverters.seqAsJavaListConverter(terms).asJava + val jTermWeights = JavaConverters.seqAsJavaListConverter(termWeights).asJava + Array[Any](jTerms, jTermWeights) + } + SerDe.dumps(JavaConverters.seqAsJavaListConverter(topics).asJava) + } + + def save(sc: SparkContext, path: String): Unit = model.save(sc, path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 40c41806cdfea..54b03a9f90283 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -517,7 +517,7 @@ private[python] class PythonMLLibAPI extends Serializable { topicConcentration: Double, seed: java.lang.Long, checkpointInterval: Int, - optimizer: String): LDAModel = { + optimizer: String): LDAModelWrapper = { val algo = new LDA() .setK(k) .setMaxIterations(maxIterations) @@ -535,7 +535,16 @@ private[python] class PythonMLLibAPI extends Serializable { case _ => throw new IllegalArgumentException("input values contains invalid type value.") } } - algo.run(documents) + val model = algo.run(documents) + new LDAModelWrapper(model) + } + + /** + * Load a LDA model + */ + def loadLDAModel(jsc: JavaSparkContext, path: String): LDAModelWrapper = { + val model = DistributedLDAModel.load(jsc.sc, path) + new LDAModelWrapper(model) } diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 8629aa5a17164..12081f8c69075 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -671,7 +671,7 @@ def predictOnValues(self, dstream): return dstream.mapValues(lambda x: self._model.predict(x)) -class LDAModel(JavaModelWrapper): +class LDAModel(JavaModelWrapper, JavaSaveable, Loader): """ A clustering model derived from the LDA method. @@ -691,9 +691,14 @@ class LDAModel(JavaModelWrapper): ... [2, SparseVector(2, {0: 1.0})], ... ] >>> rdd = sc.parallelize(data) - >>> model = LDA.train(rdd, k=2) + >>> model = LDA.train(rdd, k=2, seed=1) >>> model.vocabSize() 2 + >>> model.describeTopics() + [([1, 0], [0.5..., 0.49...]), ([0, 1], [0.5..., 0.49...])] + >>> model.describeTopics(1) + [([1], [0.5...]), ([0], [0.5...])] + >>> topics = model.topicsMatrix() >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]]) >>> assert_almost_equal(topics, topics_expect, 1) @@ -724,18 +729,17 @@ def vocabSize(self): """Vocabulary size (number of terms or terms in the vocabulary)""" return self.call("vocabSize") - @since('1.5.0') - def save(self, sc, path): - """Save the LDAModel on to disk. + @since('1.6.0') + def describeTopics(self, maxTermsPerTopic=None): + """Return the topics described by weighted terms. - :param sc: SparkContext - :param path: str, path to where the model needs to be stored. + WARNING: If vocabSize and k are large, this can return a large object! """ - if not isinstance(sc, SparkContext): - raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) - self._java_model.save(sc._jsc.sc(), path) + if maxTermsPerTopic is None: + topics = self.call("describeTopics") + else: + topics = self.call("describeTopics", maxTermsPerTopic) + return topics @classmethod @since('1.5.0') @@ -749,9 +753,8 @@ def load(cls, sc, path): raise TypeError("sc should be a SparkContext, got type %s" % type(sc)) if not isinstance(path, basestring): raise TypeError("path should be a basestring, got type %s" % type(path)) - java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load( - sc._jsc.sc(), path) - return cls(java_model) + model = callMLlibFunc("loadLDAModel", sc, path) + return LDAModel(model) class LDA(object): From ef362846eb448769bcf774fc9090a5013d459464 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 7 Nov 2015 13:37:37 -0800 Subject: [PATCH 31/88] [SPARK-9241][SQL] Supporting multiple DISTINCT columns - follow-up This PR is a follow up for PR https://github.com/apache/spark/pull/9406. It adds more documentation to the rewriting rule, removes a redundant if expression in the non-distinct aggregation path and adds a multiple distinct test to the AggregationQuerySuite. cc yhuai marmbrus Author: Herman van Hovell Closes #9541 from hvanhovell/SPARK-9241-followup. --- .../expressions/aggregate/Utils.scala | 114 ++++++++++++++---- .../execution/AggregationQuerySuite.scala | 17 +++ 2 files changed, 108 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala index 39010c3be6d4e..ac23f727829b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala @@ -222,10 +222,76 @@ object Utils { * aggregation in which the regular aggregation expressions and every distinct clause is aggregated * in a separate group. The results are then combined in a second aggregate. * - * TODO Expression cannocalization - * TODO Eliminate foldable expressions from distinct clauses. - * TODO This eliminates all distinct expressions. We could safely pass one to the aggregate - * operator. Perhaps this is a good thing? It is much simpler to plan later on... + * For example (in scala): + * {{{ + * val data = Seq( + * ("a", "ca1", "cb1", 10), + * ("a", "ca1", "cb2", 5), + * ("b", "ca1", "cb1", 13)) + * .toDF("key", "cat1", "cat2", "value") + * data.registerTempTable("data") + * + * val agg = data.groupBy($"key") + * .agg( + * countDistinct($"cat1").as("cat1_cnt"), + * countDistinct($"cat2").as("cat2_cnt"), + * sum($"value").as("total")) + * }}} + * + * This translates to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [COUNT(DISTINCT 'cat1), + * COUNT(DISTINCT 'cat2), + * sum('value)] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * LocalTableScan [...] + * }}} + * + * This rule rewrites this logical plan to the following (pseudo) logical plan: + * {{{ + * Aggregate( + * key = ['key] + * functions = [count(if (('gid = 1)) 'cat1 else null), + * count(if (('gid = 2)) 'cat2 else null), + * first(if (('gid = 0)) 'total else null) ignore nulls] + * output = ['key, 'cat1_cnt, 'cat2_cnt, 'total]) + * Aggregate( + * key = ['key, 'cat1, 'cat2, 'gid] + * functions = [sum('value)] + * output = ['key, 'cat1, 'cat2, 'gid, 'total]) + * Expand( + * projections = [('key, null, null, 0, cast('value as bigint)), + * ('key, 'cat1, null, 1, null), + * ('key, null, 'cat2, 2, null)] + * output = ['key, 'cat1, 'cat2, 'gid, 'value]) + * LocalTableScan [...] + * }}} + * + * The rule does the following things here: + * 1. Expand the data. There are three aggregation groups in this query: + * i. the non-distinct group; + * ii. the distinct 'cat1 group; + * iii. the distinct 'cat2 group. + * An expand operator is inserted to expand the child data for each group. The expand will null + * out all unused columns for the given group; this must be done in order to ensure correctness + * later on. Groups can by identified by a group id (gid) column added by the expand operator. + * 2. De-duplicate the distinct paths and aggregate the non-aggregate path. The group by clause of + * this aggregate consists of the original group by clause, all the requested distinct columns + * and the group id. Both de-duplication of distinct column and the aggregation of the + * non-distinct group take advantage of the fact that we group by the group id (gid) and that we + * have nulled out all non-relevant columns for the the given group. + * 3. Aggregating the distinct groups and combining this with the results of the non-distinct + * aggregation. In this step we use the group id to filter the inputs for the aggregate + * functions. The result of the non-distinct group are 'aggregated' by using the first operator, + * it might be more elegant to use the native UDAF merge mechanism for this in the future. + * + * This rule duplicates the input data by two or more times (# distinct groups + an optional + * non-distinct group). This will put quite a bit of memory pressure of the used aggregate and + * exchange operators. Keeping the number of distinct groups as low a possible should be priority, + * we could improve this in the current rule by applying more advanced expression cannocalization + * techniques. */ object MultipleDistinctRewriter extends Rule[LogicalPlan] { @@ -261,11 +327,10 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Functions used to modify aggregate functions and their inputs. def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) def patchAggregateFunctionChildren( - af: AggregateFunction2, - id: Literal, - attrs: Map[Expression, Expression]): AggregateFunction2 = { - af.withNewChildren(af.children.map { case afc => - evalWithinGroup(id, attrs(afc)) + af: AggregateFunction2)( + attrs: Expression => Expression): AggregateFunction2 = { + af.withNewChildren(af.children.map { + case afc => attrs(afc) }).asInstanceOf[AggregateFunction2] } @@ -288,7 +353,9 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Final aggregate val operators = expressions.map { e => val af = e.aggregateFunction - val naf = patchAggregateFunctionChildren(af, id, distinctAggChildAttrMap) + val naf = patchAggregateFunctionChildren(af) { x => + evalWithinGroup(id, distinctAggChildAttrMap(x)) + } (e, e.copy(aggregateFunction = naf, isDistinct = false)) } @@ -304,26 +371,27 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { val regularGroupId = Literal(0) val regularAggOperatorMap = regularAggExprs.map { e => // Perform the actual aggregation in the initial aggregate. - val af = patchAggregateFunctionChildren( - e.aggregateFunction, - regularGroupId, - regularAggChildAttrMap) - val a = Alias(e.copy(aggregateFunction = af), e.toString)() - - // Get the result of the first aggregate in the last aggregate. - val b = AggregateExpression2( - aggregate.First(evalWithinGroup(regularGroupId, a.toAttribute), Literal(true)), + val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap) + val operator = Alias(e.copy(aggregateFunction = af), e.toString)() + + // Select the result of the first aggregate in the last aggregate. + val result = AggregateExpression2( + aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)), mode = Complete, isDistinct = false) // Some aggregate functions (COUNT) have the special property that they can return a // non-null result without any input. We need to make sure we return a result in this case. - val c = af.defaultResult match { - case Some(lit) => Coalesce(Seq(b, lit)) - case None => b + val resultWithDefault = af.defaultResult match { + case Some(lit) => Coalesce(Seq(result, lit)) + case None => result } - (e, a, c) + // Return a Tuple3 containing: + // i. The original aggregate expression (used for look ups). + // ii. The actual aggregation operator (used in the first aggregate). + // iii. The operator that selects and returns the result (used in the second aggregate). + (e, operator, resultWithDefault) } // Construct the regular aggregate input projection only if we need one. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index ea80060e370e0..7f6fe339232ad 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -516,6 +516,23 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(3, 4, 4, 3, null) :: Nil) } + test("multiple distinct column sets") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | count(distinct value1), + | count(distinct value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(null, 3, 3) :: + Row(1, 2, 3) :: + Row(2, 2, 1) :: + Row(3, 0, 1) :: Nil) + } + test("test count") { checkAnswer( sqlContext.sql( From 4b69a42eda3aff08eb7437c353fe2cc87ed67181 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 7 Nov 2015 19:44:45 -0800 Subject: [PATCH 32/88] [SPARK-11362] [SQL] Use Spark BitSet in BroadcastNestedLoopJoin JIRA: https://issues.apache.org/jira/browse/SPARK-11362 We use scala.collection.mutable.BitSet in BroadcastNestedLoopJoin now. We should use Spark's BitSet. Author: Liang-Chi Hsieh Closes #9316 from viirya/use-spark-bitset. --- .../joins/BroadcastNestedLoopJoin.scala | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 05d20f511aef8..aab177b2e8427 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.collection.{BitSet, CompactBuffer} case class BroadcastNestedLoopJoin( @@ -95,9 +95,7 @@ case class BroadcastNestedLoopJoin( /** All rows that either match both-way, or rows from streamed joined with nulls. */ val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter => val matchedRows = new CompactBuffer[InternalRow] - // TODO: Use Spark's BitSet. - val includedBroadcastTuples = - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) + val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow val leftNulls = new GenericMutableRow(left.output.size) @@ -115,11 +113,11 @@ case class BroadcastNestedLoopJoin( case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) => matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy() streamRowMatched = true - includedBroadcastTuples += i + includedBroadcastTuples.set(i) case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) => matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy() streamRowMatched = true - includedBroadcastTuples += i + includedBroadcastTuples.set(i) case _ => } i += 1 @@ -138,8 +136,8 @@ case class BroadcastNestedLoopJoin( val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2) val allIncludedBroadcastTuples = includedBroadcastTuples.fold( - new scala.collection.mutable.BitSet(broadcastedRelation.value.size) - )(_ ++ _) + new BitSet(broadcastedRelation.value.size) + )(_ | _) val leftNulls = new GenericMutableRow(left.output.size) val rightNulls = new GenericMutableRow(right.output.size) @@ -155,7 +153,7 @@ case class BroadcastNestedLoopJoin( val joinedRow = new JoinedRow joinedRow.withLeft(leftNulls) while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { + if (!allIncludedBroadcastTuples.get(i)) { buf += resultProj(joinedRow.withRight(rel(i))).copy() } i += 1 @@ -164,7 +162,7 @@ case class BroadcastNestedLoopJoin( val joinedRow = new JoinedRow joinedRow.withRight(rightNulls) while (i < rel.length) { - if (!allIncludedBroadcastTuples.contains(i)) { + if (!allIncludedBroadcastTuples.get(i)) { buf += resultProj(joinedRow.withLeft(rel(i))).copy() } i += 1 From d981902101767b32dc83a5a639311e197f5cbcc1 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sun, 8 Nov 2015 11:15:58 +0000 Subject: [PATCH 33/88] [SPARK-11476][DOCS] Incorrect function referred to in MLib Random data generation documentation Fix Python example to use normalRDD as advertised Author: Sean Owen Closes #9529 from srowen/SPARK-11476. --- docs/mllib-statistics.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index 2c7c9ed693fd4..ade5b0768aefe 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -594,7 +594,7 @@ sc = ... # SparkContext # Generate a random double RDD that contains 1 million i.i.d. values drawn from the # standard normal distribution `N(0, 1)`, evenly distributed in 10 partitions. -u = RandomRDDs.uniformRDD(sc, 1000000L, 10) +u = RandomRDDs.normalRDD(sc, 1000000L, 10) # Apply a transform to get a random double RDD following `N(1, 4)`. v = u.map(lambda x: 1.0 + 2.0 * x) {% endhighlight %} From 5c4e6d7ec9157c02494a382dfb49e7fbde3be222 Mon Sep 17 00:00:00 2001 From: Rohit Agarwal Date: Sun, 8 Nov 2015 14:24:26 +0000 Subject: [PATCH 34/88] [DOC][SQL] Remove redundant out-of-place python snippet This snippet seems to be mistakenly introduced at two places in #5348. Author: Rohit Agarwal Closes #9540 from mindprince/patch-1. --- docs/sql-programming-guide.md | 9 --------- 1 file changed, 9 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2fe5c36338899..085874133d968 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1089,15 +1089,6 @@ for (teenName in collect(teenNames)) { -
- -{% highlight python %} -# sqlContext is an existing HiveContext -sqlContext.sql("REFRESH TABLE my_table") -{% endhighlight %} - -
-
{% highlight sql %} From 30c8ba71a76788cbc6916bc1ba6bc8522925fc2b Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sun, 8 Nov 2015 11:06:10 -0800 Subject: [PATCH 35/88] [SPARK-11451][SQL] Support single distinct count on multiple columns. This PR adds support for multiple column in a single count distinct aggregate to the new aggregation path. cc yhuai Author: Herman van Hovell Closes #9409 from hvanhovell/SPARK-11451. --- .../expressions/aggregate/Utils.scala | 44 +++++++++++-------- .../expressions/conditionalExpressions.scala | 30 ++++++++++++- .../plans/logical/basicOperators.scala | 3 ++ .../ConditionalExpressionSuite.scala | 14 ++++++ .../spark/sql/DataFrameAggregateSuite.scala | 25 +++++++++++ .../execution/AggregationQuerySuite.scala | 37 +++++++++++++--- 6 files changed, 127 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala index ac23f727829b6..9b22ce2619731 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala @@ -22,26 +22,27 @@ import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types.{IntegerType, StructType, MapType, ArrayType} +import org.apache.spark.sql.types._ /** * Utility functions used by the query planner to convert our plan to new aggregation code path. */ object Utils { - // Right now, we do not support complex types in the grouping key schema. - private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { - val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { - case array: ArrayType => true - case map: MapType => true - case struct: StructType => true - case _ => false - } - !hasComplexTypes + // Check if the DataType given cannot be part of a group by clause. + private def isUnGroupable(dt: DataType): Boolean = dt match { + case _: ArrayType | _: MapType => true + case s: StructType => s.fields.exists(f => isUnGroupable(f.dataType)) + case _ => false } + // Right now, we do not support complex types in the grouping key schema. + private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = + !aggregate.groupingExpressions.exists(e => isUnGroupable(e.dataType)) + private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { case p: Aggregate if supportsGroupingKeySchema(p) => + val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown { case expressions.Average(child) => aggregate.AggregateExpression2( @@ -55,10 +56,14 @@ object Utils { mode = aggregate.Complete, isDistinct = false) - // We do not support multiple COUNT DISTINCT columns for now. - case expressions.CountDistinct(children) if children.length == 1 => + case expressions.CountDistinct(children) => + val child = if (children.size > 1) { + DropAnyNull(CreateStruct(children)) + } else { + children.head + } aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(children.head), + aggregateFunction = aggregate.Count(child), mode = aggregate.Complete, isDistinct = true) @@ -320,7 +325,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { val gid = new AttributeReference("gid", IntegerType, false)() val groupByMap = a.groupingExpressions.collect { case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> new AttributeReference(e.prettyName, e.dataType, e.nullable)() + case e => e -> new AttributeReference(e.prettyString, e.dataType, e.nullable)() } val groupByAttrs = groupByMap.map(_._2) @@ -365,14 +370,15 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Setup expand for the 'regular' aggregate expressions. val regularAggExprs = aggExpressions.filter(!_.isDistinct) val regularAggChildren = regularAggExprs.flatMap(_.aggregateFunction.children).distinct - val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair).toMap + val regularAggChildAttrMap = regularAggChildren.map(expressionAttributePair) // Setup aggregates for 'regular' aggregate expressions. val regularGroupId = Literal(0) + val regularAggChildAttrLookup = regularAggChildAttrMap.toMap val regularAggOperatorMap = regularAggExprs.map { e => // Perform the actual aggregation in the initial aggregate. - val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrMap) - val operator = Alias(e.copy(aggregateFunction = af), e.toString)() + val af = patchAggregateFunctionChildren(e.aggregateFunction)(regularAggChildAttrLookup) + val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)() // Select the result of the first aggregate in the last aggregate. val result = AggregateExpression2( @@ -416,7 +422,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Construct the expand operator. val expand = Expand( regularAggProjection ++ distinctAggProjections, - groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.values.toSeq, + groupByAttrs ++ distinctAggChildAttrs ++ Seq(gid) ++ regularAggChildAttrMap.map(_._2), a.child) // Construct the first aggregate operator. This de-duplicates the all the children of @@ -457,5 +463,5 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // NamedExpression. This is done to prevent collisions between distinct and regular aggregate // children, in this case attribute reuse causes the input of the regular aggregate to bound to // the (nulled out) input of the distinct aggregate. - e -> new AttributeReference(e.prettyName, e.dataType, true)() + e -> new AttributeReference(e.prettyString, e.dataType, true)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index d532629984bec..0d4af43978ea1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.types.{NullType, BooleanType, DataType} +import org.apache.spark.sql.types._ case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) @@ -419,3 +419,31 @@ case class Greatest(children: Seq[Expression]) extends Expression { """ } } + +/** Operator that drops a row when it contains any nulls. */ +case class DropAnyNull(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def inputTypes: Seq[AbstractDataType] = Seq(StructType) + + protected override def nullSafeEval(input: Any): InternalRow = { + val row = input.asInstanceOf[InternalRow] + if (row.anyNull) { + null + } else { + row + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, eval => { + s""" + if ($eval.anyNull()) { + ${ev.isNull} = true; + } else { + ${ev.value} = $eval; + } + """ + }) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index fb963e2f8f7e7..09aac00a455f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -306,6 +306,9 @@ case class Expand( output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { + override def references: AttributeSet = + AttributeSet(projections.flatten.flatMap(_.references)) + override def statistics: Statistics = { // TODO shouldn't we factor in the size of the projection versus the size of the backing child // row? diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 0df673bb9fa02..c1e3c17b87102 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -231,4 +231,18 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkConsistencyBetweenInterpretedAndCodegen(Greatest, dt, 2) } } + + test("function dropAnyNull") { + val drop = DropAnyNull(CreateStruct(Seq('a.string.at(0), 'b.string.at(1)))) + val a = create_row("a", "q") + val nullStr: String = null + checkEvaluation(drop, a, a) + checkEvaluation(drop, null, create_row("b", nullStr)) + checkEvaluation(drop, null, create_row(nullStr, nullStr)) + + val row = 'r.struct( + StructField("a", StringType, false), + StructField("b", StringType, true)).at(0) + checkEvaluation(DropAnyNull(row), null, create_row(null)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 2e679e7bc4e0a..eb1ee266c5d28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -162,6 +162,31 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } + test("multiple column distinct count") { + val df1 = Seq( + ("a", "b", "c"), + ("a", "b", "c"), + ("a", "b", "d"), + ("x", "y", "z"), + ("x", "q", null.asInstanceOf[String])) + .toDF("key1", "key2", "key3") + + checkAnswer( + df1.agg(countDistinct('key1, 'key2)), + Row(3) + ) + + checkAnswer( + df1.agg(countDistinct('key1, 'key2, 'key3)), + Row(3) + ) + + checkAnswer( + df1.groupBy('key1).agg(countDistinct('key2, 'key3)), + Seq(Row("a", 2), Row("x", 1)) + ) + } + test("zero count") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 7f6fe339232ad..ea36c132bb190 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -516,21 +516,46 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te Row(3, 4, 4, 3, null) :: Nil) } - test("multiple distinct column sets") { + test("single distinct multiple columns set") { + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | count(distinct value1, value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(null, 3) :: + Row(1, 3) :: + Row(2, 1) :: + Row(3, 0) :: Nil) + } + + test("multiple distinct multiple columns sets") { checkAnswer( sqlContext.sql( """ |SELECT | key, | count(distinct value1), - | count(distinct value2) + | sum(distinct value1), + | count(distinct value2), + | sum(distinct value2), + | count(distinct value1, value2), + | count(value1), + | sum(value1), + | count(value2), + | sum(value2), + | count(*), + | count(1) |FROM agg2 |GROUP BY key """.stripMargin), - Row(null, 3, 3) :: - Row(1, 2, 3) :: - Row(2, 2, 1) :: - Row(3, 0, 1) :: Nil) + Row(null, 3, 30, 3, 60, 3, 3, 30, 3, 60, 4, 4) :: + Row(1, 2, 40, 3, -10, 3, 3, 70, 3, -10, 3, 3) :: + Row(2, 2, 0, 1, 1, 1, 3, 1, 3, 3, 4, 4) :: + Row(3, 0, null, 1, 3, 0, 0, null, 1, 3, 2, 2) :: Nil) } test("test count") { From 26739059bc39cd7aa7e0b1c16089c1cf8d8e4d7d Mon Sep 17 00:00:00 2001 From: xin Wu Date: Sun, 8 Nov 2015 12:28:19 -0800 Subject: [PATCH 36/88] =?UTF-8?q?[SPARK-10046][SQL]=20Hive=20warehouse=20d?= =?UTF-8?q?ir=20not=20set=20in=20current=20directory=20when=20not=20?= =?UTF-8?q?=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Doc change to align with HiveConf default in terms of where to create `warehouse` directory. Author: xin Wu Closes #9365 from xwu0226/spark-10046-commit. --- docs/sql-programming-guide.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 085874133d968..52e03b951f966 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1627,8 +1627,10 @@ YARN cluster. The convenient way to do this is adding them through the `--jars` When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do not have an existing Hive deployment can still create a `HiveContext`. When not configured by the -hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current -directory. +hive-site.xml, the context automatically creates `metastore_db` in the current directory and +creates `warehouse` directory indicated by HiveConf, which defaults to `/user/hive/warehouse`. +Note that you may need to grant write privilege on `/user/hive/warehouse` to the user who starts +the spark application. {% highlight scala %} // sc is an existing SparkContext. From b2d195e137fad88d567974659fa7023ff4da96cd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 8 Nov 2015 12:59:35 -0800 Subject: [PATCH 37/88] [SPARK-11554][SQL] add map/flatMap to GroupedDataset Author: Wenchen Fan Closes #9521 from cloud-fan/map. --- .../plans/logical/basicOperators.scala | 4 +- .../org/apache/spark/sql/GroupedDataset.scala | 29 ++++++++++++-- .../spark/sql/execution/basicOperators.scala | 2 +- .../apache/spark/sql/JavaDatasetSuite.java | 16 ++++---- .../spark/sql/DatasetPrimitiveSuite.scala | 16 ++++++-- .../org/apache/spark/sql/DatasetSuite.scala | 40 +++++++++---------- 6 files changed, 70 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 09aac00a455f9..e151ac04ede2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -494,7 +494,7 @@ case class AppendColumn[T, U]( /** Factory for constructing new `MapGroups` nodes. */ object MapGroups { def apply[K : Encoder, T : Encoder, U : Encoder]( - func: (K, Iterator[T]) => Iterator[U], + func: (K, Iterator[T]) => TraversableOnce[U], groupingAttributes: Seq[Attribute], child: LogicalPlan): MapGroups[K, T, U] = { new MapGroups( @@ -514,7 +514,7 @@ object MapGroups { * object representation of all the rows with that key. */ case class MapGroups[K, T, U]( - func: (K, Iterator[T]) => Iterator[U], + func: (K, Iterator[T]) => TraversableOnce[U], kEncoder: ExpressionEncoder[K], tEncoder: ExpressionEncoder[T], uEncoder: ExpressionEncoder[U], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index b2803d5a9a1e3..5c3f626545875 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -102,16 +102,39 @@ class GroupedDataset[K, T] private[sql]( * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. */ - def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = { + def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = { new Dataset[U]( sqlContext, MapGroups(f, groupingAttributes, logicalPlan)) } - def mapGroups[U]( + def flatMap[U]( f: JFunction2[K, JIterator[T], JIterator[U]], encoder: Encoder[U]): Dataset[U] = { - mapGroups((key, data) => f.call(key, data.asJava).asScala)(encoder) + flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an element of arbitrary type which will be returned as a new [[Dataset]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + */ + def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = { + val func = (key: K, it: Iterator[T]) => Iterator(f(key, it)) + new Dataset[U]( + sqlContext, + MapGroups(func, groupingAttributes, logicalPlan)) + } + + def map[U]( + f: JFunction2[K, JIterator[T], U], + encoder: Encoder[U]): Dataset[U] = { + map((key, data) => f.call(key, data.asJava))(encoder) } // To ensure valid overloading. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 799650a4f784f..2593b16b1c8d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -356,7 +356,7 @@ case class AppendColumns[T, U]( * being output. */ case class MapGroups[K, T, U]( - func: (K, Iterator[T]) => Iterator[U], + func: (K, Iterator[T]) => TraversableOnce[U], kEncoder: ExpressionEncoder[K], tEncoder: ExpressionEncoder[T], uEncoder: ExpressionEncoder[U], diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index a9493d576d179..0d3b1a5af52c4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -170,15 +170,15 @@ public Integer call(String v) throws Exception { } }, e.INT()); - Dataset mapped = grouped.mapGroups( - new Function2, Iterator>() { + Dataset mapped = grouped.map( + new Function2, String>() { @Override - public Iterator call(Integer key, Iterator data) throws Exception { + public String call(Integer key, Iterator data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); while (data.hasNext()) { sb.append(data.next()); } - return Collections.singletonList(sb.toString()).iterator(); + return sb.toString(); } }, e.STRING()); @@ -224,15 +224,15 @@ public void testGroupByColumn() { Dataset ds = context.createDataset(data, e.STRING()); GroupedDataset grouped = ds.groupBy(length(col("value"))).asKey(e.INT()); - Dataset mapped = grouped.mapGroups( - new Function2, Iterator>() { + Dataset mapped = grouped.map( + new Function2, String>() { @Override - public Iterator call(Integer key, Iterator data) throws Exception { + public String call(Integer key, Iterator data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); while (data.hasNext()) { sb.append(data.next()); } - return Collections.singletonList(sb.toString()).iterator(); + return sb.toString(); } }, e.STRING()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index e3b0346f857d3..fcf03f7180984 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -88,16 +88,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { 0, 1) } - test("groupBy function, mapGroups") { + test("groupBy function, map") { val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() val grouped = ds.groupBy(_ % 2) - val agged = grouped.mapGroups { case (g, iter) => + val agged = grouped.map { case (g, iter) => val name = if (g == 0) "even" else "odd" - Iterator((name, iter.size)) + (name, iter.size) } checkAnswer( agged, ("even", 5), ("odd", 6)) } + + test("groupBy function, flatMap") { + val ds = Seq("a", "b", "c", "xyz", "hello").toDS() + val grouped = ds.groupBy(_.length) + val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) } + + checkAnswer( + agged, + "1", "abc", "3", "xyz", "5", "hello") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d61e17edc64ed..6f1174e6577e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -198,60 +198,60 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (1, 1)) } - test("groupBy function, mapGroups") { + test("groupBy function, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy(v => (v._1, "word")) - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g._1, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g._1, iter.map(_._2).sum) } checkAnswer( agged, ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy columns, mapGroups") { + test("groupBy function, fatMap") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy(v => (v._1, "word")) + val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) } + + checkAnswer( + agged, + "a", "30", "b", "3", "c", "1") + } + + test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g.getString(0), iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g.getString(0), iter.map(_._2).sum) } checkAnswer( agged, ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy columns asKey, mapGroups") { + test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").asKey[String] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, ("a", 30), ("b", 3), ("c", 1)) } - test("groupBy columns asKey tuple, mapGroups") { + test("groupBy columns asKey tuple, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, (("a", 1), 30), (("b", 1), 3), (("c", 1), 1)) } - test("groupBy columns asKey class, mapGroups") { + test("groupBy columns asKey class, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData] - val agged = grouped.mapGroups { case (g, iter) => - Iterator((g, iter.map(_._2).sum)) - } + val agged = grouped.map { case (g, iter) => (g, iter.map(_._2).sum) } checkAnswer( agged, From 97b7080cf2d2846c7257f8926f775f27d457fe7d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 8 Nov 2015 20:57:09 -0800 Subject: [PATCH 38/88] [SPARK-11564][SQL] Dataset Java API audit A few changes: 1. Removed fold, since it can be confusing for distributed collections. 2. Created specific interfaces for each Dataset function (e.g. MapFunction, ReduceFunction, MapPartitionsFunction) 3. Added more documentation and test cases. The other thing I'm considering doing is to have a "collector" interface for FlatMapFunction and MapPartitionsFunction, similar to MapReduce's map function. Author: Reynold Xin Closes #9531 from rxin/SPARK-11564. --- .../api/java/function/FilterFunction.java | 29 +++++ .../api/java/function/ForeachFunction.java | 29 +++++ .../function/ForeachPartitionFunction.java | 28 +++++ .../spark/api/java/function/Function0.java | 2 +- .../spark/api/java/function/MapFunction.java | 27 +++++ .../java/function/MapPartitionsFunction.java | 28 +++++ .../api/java/function/ReduceFunction.java | 27 +++++ .../spark/sql/catalyst/encoders/Encoder.scala | 38 +++++-- .../org/apache/spark/sql/DataFrame.scala | 47 ++++++-- .../scala/org/apache/spark/sql/Dataset.scala | 100 +++++++++--------- .../apache/spark/sql/JavaDataFrameSuite.java | 7 ++ .../apache/spark/sql/JavaDatasetSuite.java | 36 +++---- .../spark/sql/DatasetPrimitiveSuite.scala | 5 - .../org/apache/spark/sql/DatasetSuite.scala | 10 +- 14 files changed, 316 insertions(+), 97 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/MapFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java new file mode 100644 index 0000000000000..e8d999dd00135 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/FilterFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a function used in Dataset's filter function. + * + * If the function returns true, the element is discarded in the returned Dataset. + */ +public interface FilterFunction extends Serializable { + boolean call(T value) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java new file mode 100644 index 0000000000000..07e54b28fa12c --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a function used in Dataset's foreach function. + * + * Spark will invoke the call function on each element in the input Dataset. + */ +public interface ForeachFunction extends Serializable { + void call(T t) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java new file mode 100644 index 0000000000000..4938a51bcd712 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ForeachPartitionFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for a function used in Dataset's foreachPartition function. + */ +public interface ForeachPartitionFunction extends Serializable { + void call(Iterator t) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function0.java b/core/src/main/java/org/apache/spark/api/java/function/Function0.java index 38e410c5debe6..c86928dd05408 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/Function0.java +++ b/core/src/main/java/org/apache/spark/api/java/function/Function0.java @@ -23,5 +23,5 @@ * A zero-argument function that returns an R. */ public interface Function0 extends Serializable { - public R call() throws Exception; + R call() throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java new file mode 100644 index 0000000000000..3ae6ef44898e1 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapFunction.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for a map function used in Dataset's map function. + */ +public interface MapFunction extends Serializable { + U call(T value) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java new file mode 100644 index 0000000000000..6cb569ce0cb6b --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapPartitionsFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for function used in Dataset's mapPartitions. + */ +public interface MapPartitionsFunction extends Serializable { + Iterable call(Iterator input) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java new file mode 100644 index 0000000000000..ee092d0058f44 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/ReduceFunction.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * Base interface for function used in Dataset's reduce. + */ +public interface ReduceFunction extends Serializable { + T call(T v1, T v2) throws Exception; +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala index f05e18288de2b..6569b900fed90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders import scala.reflect.ClassTag import org.apache.spark.util.Utils -import org.apache.spark.sql.types.{DataType, ObjectType, StructField, StructType} +import org.apache.spark.sql.types.{ObjectType, StructField, StructType} import org.apache.spark.sql.catalyst.expressions._ /** @@ -100,7 +100,7 @@ object Encoder { expr.transformUp { case BoundReference(0, t: ObjectType, _) => Invoke( - BoundReference(0, ObjectType(cls), true), + BoundReference(0, ObjectType(cls), nullable = true), s"_${index + 1}", t) } @@ -114,13 +114,13 @@ object Encoder { } else { enc.constructExpression.transformUp { case BoundReference(ordinal, dt, _) => - GetInternalRowField(BoundReference(index, enc.schema, true), ordinal, dt) + GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt) } } } val constructExpression = - NewInstance(cls, constructExpressions, false, ObjectType(cls)) + NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls)) new ExpressionEncoder[Any]( schema, @@ -130,7 +130,6 @@ object Encoder { ClassTag.apply(cls)) } - def typeTagOfTuple2[T1 : TypeTag, T2 : TypeTag]: TypeTag[(T1, T2)] = typeTag[(T1, T2)] private def getTypeTag[T](c: Class[T]): TypeTag[T] = { @@ -148,9 +147,36 @@ object Encoder { }) } - def forTuple2[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = { + def forTuple[T1, T2](c1: Class[T1], c2: Class[T2]): Encoder[(T1, T2)] = { implicit val typeTag1 = getTypeTag(c1) implicit val typeTag2 = getTypeTag(c2) ExpressionEncoder[(T1, T2)]() } + + def forTuple[T1, T2, T3](c1: Class[T1], c2: Class[T2], c3: Class[T3]): Encoder[(T1, T2, T3)] = { + implicit val typeTag1 = getTypeTag(c1) + implicit val typeTag2 = getTypeTag(c2) + implicit val typeTag3 = getTypeTag(c3) + ExpressionEncoder[(T1, T2, T3)]() + } + + def forTuple[T1, T2, T3, T4]( + c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4]): Encoder[(T1, T2, T3, T4)] = { + implicit val typeTag1 = getTypeTag(c1) + implicit val typeTag2 = getTypeTag(c2) + implicit val typeTag3 = getTypeTag(c3) + implicit val typeTag4 = getTypeTag(c4) + ExpressionEncoder[(T1, T2, T3, T4)]() + } + + def forTuple[T1, T2, T3, T4, T5]( + c1: Class[T1], c2: Class[T2], c3: Class[T3], c4: Class[T4], c5: Class[T5]) + : Encoder[(T1, T2, T3, T4, T5)] = { + implicit val typeTag1 = getTypeTag(c1) + implicit val typeTag2 = getTypeTag(c2) + implicit val typeTag3 = getTypeTag(c3) + implicit val typeTag4 = getTypeTag(c4) + implicit val typeTag5 = getTypeTag(c5) + ExpressionEncoder[(T1, T2, T3, T4, T5)]() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f2d4db5550273..8ab958adadcca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1478,18 +1478,54 @@ class DataFrame private[sql]( /** * Returns the first `n` rows in the [[DataFrame]]. + * + * Running take requires moving data into the application's driver process, and doing so on a + * very large dataset can crash the driver process with OutOfMemoryError. + * * @group action * @since 1.3.0 */ def take(n: Int): Array[Row] = head(n) + /** + * Returns the first `n` rows in the [[DataFrame]] as a list. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * @group action + * @since 1.6.0 + */ + def takeAsList(n: Int): java.util.List[Row] = java.util.Arrays.asList(take(n) : _*) + /** * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * For Java API, use [[collectAsList]]. + * * @group action * @since 1.3.0 */ def collect(): Array[Row] = collect(needCallback = true) + /** + * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. + * + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * + * @group action + * @since 1.3.0 + */ + def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ => + withNewExecutionId { + java.util.Arrays.asList(rdd.collect() : _*) + } + } + private def collect(needCallback: Boolean): Array[Row] = { def execute(): Array[Row] = withNewExecutionId { queryExecution.executedPlan.executeCollectPublic() @@ -1502,17 +1538,6 @@ class DataFrame private[sql]( } } - /** - * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. - * @group action - * @since 1.3.0 - */ - def collectAsList(): java.util.List[Row] = withCallback("collectAsList", this) { _ => - withNewExecutionId { - java.util.Arrays.asList(rdd.collect() : _*) - } - } - /** * Returns the number of rows in the [[DataFrame]]. * @group action diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index fecbdac9a6004..959e0f5ba03e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _} +import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ @@ -75,7 +75,11 @@ class Dataset[T] private[sql]( private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = this(sqlContext, new QueryExecution(sqlContext, plan), encoder) - /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */ + /** + * Returns the schema of the encoded form of the objects in this [[Dataset]]. + * + * @since 1.6.0 + */ def schema: StructType = encoder.schema /* ************* * @@ -103,6 +107,7 @@ class Dataset[T] private[sql]( /** * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have * the same name after two Datasets have been joined. + * @since 1.6.0 */ def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _)) @@ -166,8 +171,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. * @since 1.6.0 */ - def filter(func: JFunction[T, java.lang.Boolean]): Dataset[T] = - filter(t => func.call(t).booleanValue()) + def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t)) /** * (Scala-specific) @@ -181,7 +185,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ - def map[U](func: JFunction[T, U], encoder: Encoder[U]): Dataset[U] = + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = map(t => func.call(t))(encoder) /** @@ -205,10 +209,8 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] that contains the result of applying `func` to each element. * @since 1.6.0 */ - def mapPartitions[U]( - f: FlatMapFunction[java.util.Iterator[T], U], - encoder: Encoder[U]): Dataset[U] = { - val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator().asScala + def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + val func: (Iterator[T]) => Iterator[U] = x => f.call(x.asJava).iterator.asScala mapPartitions(func)(encoder) } @@ -248,7 +250,7 @@ class Dataset[T] private[sql]( * Runs `func` on each element of this Dataset. * @since 1.6.0 */ - def foreach(func: VoidFunction[T]): Unit = foreach(func.call(_)) + def foreach(func: ForeachFunction[T]): Unit = foreach(func.call(_)) /** * (Scala-specific) @@ -262,7 +264,7 @@ class Dataset[T] private[sql]( * Runs `func` on each partition of this Dataset. * @since 1.6.0 */ - def foreachPartition(func: VoidFunction[java.util.Iterator[T]]): Unit = + def foreachPartition(func: ForeachPartitionFunction[T]): Unit = foreachPartition(it => func.call(it.asJava)) /* ************* * @@ -271,7 +273,7 @@ class Dataset[T] private[sql]( /** * (Scala-specific) - * Reduces the elements of this Dataset using the specified binary function. The given function + * Reduces the elements of this Dataset using the specified binary function. The given function * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 */ @@ -279,33 +281,11 @@ class Dataset[T] private[sql]( /** * (Java-specific) - * Reduces the elements of this Dataset using the specified binary function. The given function + * Reduces the elements of this Dataset using the specified binary function. The given function * must be commutative and associative or the result may be non-deterministic. * @since 1.6.0 */ - def reduce(func: JFunction2[T, T, T]): T = reduce(func.call(_, _)) - - /** - * (Scala-specific) - * Aggregates the elements of each partition, and then the results for all the partitions, using a - * given associative and commutative function and a neutral "zero value". - * - * This behaves somewhat differently than the fold operations implemented for non-distributed - * collections in functional languages like Scala. This fold operation may be applied to - * partitions individually, and then those results will be folded into the final result. - * If op is not commutative, then the result may differ from that of a fold applied to a - * non-distributed collection. - * @since 1.6.0 - */ - def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op) - - /** - * (Java-specific) - * Aggregates the elements of each partition, and then the results for all the partitions, using a - * given associative and commutative function and a neutral "zero value". - * @since 1.6.0 - */ - def fold(zeroValue: T, func: JFunction2[T, T, T]): T = fold(zeroValue)(func.call(_, _)) + def reduce(func: ReduceFunction[T]): T = reduce(func.call(_, _)) /** * (Scala-specific) @@ -351,7 +331,7 @@ class Dataset[T] private[sql]( * Returns a [[GroupedDataset]] where the data is grouped by the given key function. * @since 1.6.0 */ - def groupBy[K](f: JFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = + def groupBy[K](f: MapFunction[T, K], encoder: Encoder[K]): GroupedDataset[K, T] = groupBy(f.call(_))(encoder) /* ****************** * @@ -367,7 +347,7 @@ class Dataset[T] private[sql]( */ // Copied from Dataframe to make sure we don't have invalid overloads. @scala.annotation.varargs - def select(cols: Column*): DataFrame = toDF().select(cols: _*) + protected def select(cols: Column*): DataFrame = toDF().select(cols: _*) /** * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. @@ -462,8 +442,7 @@ class Dataset[T] private[sql]( * and thus is not affected by a custom `equals` function defined on `T`. * @since 1.6.0 */ - def intersect(other: Dataset[T]): Dataset[T] = - withPlan[T](other)(Intersect) + def intersect(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Intersect) /** * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]] @@ -473,8 +452,7 @@ class Dataset[T] private[sql]( * duplicate items. As such, it is analagous to `UNION ALL` in SQL. * @since 1.6.0 */ - def union(other: Dataset[T]): Dataset[T] = - withPlan[T](other)(Union) + def union(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Union) /** * Returns a new [[Dataset]] where any elements present in `other` have been removed. @@ -542,27 +520,47 @@ class Dataset[T] private[sql]( def first(): T = rdd.first() /** - * Collects the elements to an Array. + * Returns an array that contains all the elements in this [[Dataset]]. + * + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. + * + * For Java API, use [[collectAsList]]. * @since 1.6.0 */ def collect(): Array[T] = rdd.collect() /** - * (Java-specific) - * Collects the elements to a Java list. + * Returns an array that contains all the elements in this [[Dataset]]. * - * Due to the incompatibility problem between Scala and Java, the return type of [[collect()]] at - * Java side is `java.lang.Object`, which is not easy to use. Java user can use this method - * instead and keep the generic type for result. + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. * + * For Java API, use [[collectAsList]]. * @since 1.6.0 */ - def collectAsList(): java.util.List[T] = - rdd.collect().toSeq.asJava + def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava - /** Returns the first `num` elements of this [[Dataset]] as an Array. */ + /** + * Returns the first `num` elements of this [[Dataset]] as an array. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * @since 1.6.0 + */ def take(num: Int): Array[T] = rdd.take(num) + /** + * Returns the first `num` elements of this [[Dataset]] as an array. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. + * + * @since 1.6.0 + */ + def takeAsList(num: Int): java.util.List[T] = java.util.Arrays.asList(take(num) : _*) + /* ******************** * * Internal Functions * * ******************** */ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 40bff57a17a03..d191b50fa802e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -65,6 +65,13 @@ public void testExecution() { Assert.assertEquals(1, df.select("key").collect()[0].get(0)); } + @Test + public void testCollectAndTake() { + DataFrame df = context.table("testData").filter("key = 1 or key = 2 or key = 3"); + Assert.assertEquals(3, df.select("key").collectAsList().size()); + Assert.assertEquals(2, df.select("key").takeAsList(2).size()); + } + /** * See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java. */ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 0d3b1a5af52c4..0f90de774dd3e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -68,8 +68,16 @@ private Tuple2 tuple2(T1 t1, T2 t2) { public void testCollect() { List data = Arrays.asList("hello", "world"); Dataset ds = context.createDataset(data, e.STRING()); - String[] collected = (String[]) ds.collect(); - Assert.assertEquals(Arrays.asList("hello", "world"), Arrays.asList(collected)); + List collected = ds.collectAsList(); + Assert.assertEquals(Arrays.asList("hello", "world"), collected); + } + + @Test + public void testTake() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, e.STRING()); + List collected = ds.takeAsList(1); + Assert.assertEquals(Arrays.asList("hello"), collected); } @Test @@ -78,16 +86,16 @@ public void testCommonOperation() { Dataset ds = context.createDataset(data, e.STRING()); Assert.assertEquals("hello", ds.first()); - Dataset filtered = ds.filter(new Function() { + Dataset filtered = ds.filter(new FilterFunction() { @Override - public Boolean call(String v) throws Exception { + public boolean call(String v) throws Exception { return v.startsWith("h"); } }); Assert.assertEquals(Arrays.asList("hello"), filtered.collectAsList()); - Dataset mapped = ds.map(new Function() { + Dataset mapped = ds.map(new MapFunction() { @Override public Integer call(String v) throws Exception { return v.length(); @@ -95,7 +103,7 @@ public Integer call(String v) throws Exception { }, e.INT()); Assert.assertEquals(Arrays.asList(5, 5), mapped.collectAsList()); - Dataset parMapped = ds.mapPartitions(new FlatMapFunction, String>() { + Dataset parMapped = ds.mapPartitions(new MapPartitionsFunction() { @Override public Iterable call(Iterator it) throws Exception { List ls = new LinkedList(); @@ -128,7 +136,7 @@ public void testForeach() { List data = Arrays.asList("a", "b", "c"); Dataset ds = context.createDataset(data, e.STRING()); - ds.foreach(new VoidFunction() { + ds.foreach(new ForeachFunction() { @Override public void call(String s) throws Exception { accum.add(1); @@ -142,28 +150,20 @@ public void testReduce() { List data = Arrays.asList(1, 2, 3); Dataset ds = context.createDataset(data, e.INT()); - int reduced = ds.reduce(new Function2() { + int reduced = ds.reduce(new ReduceFunction() { @Override public Integer call(Integer v1, Integer v2) throws Exception { return v1 + v2; } }); Assert.assertEquals(6, reduced); - - int folded = ds.fold(1, new Function2() { - @Override - public Integer call(Integer v1, Integer v2) throws Exception { - return v1 * v2; - } - }); - Assert.assertEquals(6, folded); } @Test public void testGroupBy() { List data = Arrays.asList("a", "foo", "bar"); Dataset ds = context.createDataset(data, e.STRING()); - GroupedDataset grouped = ds.groupBy(new Function() { + GroupedDataset grouped = ds.groupBy(new MapFunction() { @Override public Integer call(String v) throws Exception { return v.length(); @@ -187,7 +187,7 @@ public String call(Integer key, Iterator data) throws Exception { List data2 = Arrays.asList(2, 6, 10); Dataset ds2 = context.createDataset(data2, e.INT()); - GroupedDataset grouped2 = ds2.groupBy(new Function() { + GroupedDataset grouped2 = ds2.groupBy(new MapFunction() { @Override public Integer call(Integer v) throws Exception { return v / 2; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index fcf03f7180984..63b00975e4eb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -75,11 +75,6 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { assert(ds.reduce(_ + _) == 6) } - test("fold") { - val ds = Seq(1, 2, 3).toDS() - assert(ds.fold(0)(_ + _) == 6) - } - test("groupBy function, keys") { val ds = Seq(1, 2, 3, 4, 5).toDS() val grouped = ds.groupBy(_ % 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 6f1174e6577e3..aea5a700d0204 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -61,6 +61,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))) } + test("as case class - take") { + val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData] + assert(ds.take(2) === Array(ClassData("a", 1), ClassData("b", 2))) + } + test("map") { val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() checkAnswer( @@ -137,11 +142,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) } - test("fold") { - val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() - assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) - } - test("joinWith, flat schema") { val ds1 = Seq(1, 2, 3).toDS().as("a") val ds2 = Seq(1, 2).toDS().as("b") From d8b50f70298dbf45e91074ee2d751fee7eecb119 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 8 Nov 2015 21:01:53 -0800 Subject: [PATCH 39/88] [SPARK-11453][SQL] append data to partitioned table will messes up the result The reason is that: 1. For partitioned hive table, we will move the partitioned columns after data columns. (e.g. `` partition by `a` will become ``) 2. When append data to table, we use position to figure out how to match input columns to table's columns. So when we append data to partitioned table, we will match wrong columns between input and table. A solution is reordering the input columns before match by position, like what we did for [`InsertIntoHadoopFsRelation`](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala#L101-L105) Author: Wenchen Fan Closes #9408 from cloud-fan/append. --- .../apache/spark/sql/DataFrameWriter.scala | 29 ++++++++++++++++--- .../sql/sources/PartitionedWriteSuite.scala | 8 +++++ .../sql/hive/execution/SQLQuerySuite.scala | 20 +++++++++++++ 3 files changed, 53 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 7887e559a3025..e63a4d5e8b10b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -23,8 +23,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Project, InsertIntoTable} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.sources.HadoopFsRelation @@ -167,17 +167,38 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def insertInto(tableIdent: TableIdentifier): Unit = { - val partitions = partitioningColumns.map(_.map(col => col -> (None: Option[String])).toMap) + val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) val overwrite = mode == SaveMode.Overwrite + + // A partitioned relation's schema can be different from the input logicalPlan, since + // partition columns are all moved after data columns. We Project to adjust the ordering. + // TODO: this belongs to the analyzer. + val input = normalizedParCols.map { parCols => + val (inputPartCols, inputDataCols) = df.logicalPlan.output.partition { attr => + parCols.contains(attr.name) + } + Project(inputDataCols ++ inputPartCols, df.logicalPlan) + }.getOrElse(df.logicalPlan) + df.sqlContext.executePlan( InsertIntoTable( UnresolvedRelation(tableIdent), partitions.getOrElse(Map.empty[String, Option[String]]), - df.logicalPlan, + input, overwrite, ifNotExists = false)).toRdd } + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { parCols => + parCols.map { col => + df.logicalPlan.output + .map(_.name) + .find(df.sqlContext.analyzer.resolver(_, col)) + .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + + s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) + } + } + /** * Saves the content of the [[DataFrame]] as the specified table. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index c9791879ec74c..3eaa817f9c0b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -53,4 +53,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { Utils.deleteRecursively(path) } + + test("partitioned columns should appear at the end of schema") { + withTempPath { f => + val path = f.getAbsolutePath + Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path) + assert(sqlContext.read.parquet(path).schema.map(_.name) == Seq("j", "i")) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index af48d478953b4..9a425d7f6b265 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1428,4 +1428,24 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sql("SELECT val FROM tbl10562 WHERE Year == 2012"), Row("a")) } } + + test("SPARK-11453: append data to partitioned table") { + withTable("tbl11453") { + Seq("1" -> "10", "2" -> "20").toDF("i", "j") + .write.partitionBy("i").saveAsTable("tbl11453") + + Seq("3" -> "30").toDF("i", "j") + .write.mode(SaveMode.Append).partitionBy("i").saveAsTable("tbl11453") + checkAnswer( + sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Nil) + + // make sure case sensitivity is correct. + Seq("4" -> "40").toDF("i", "j") + .write.mode(SaveMode.Append).partitionBy("I").saveAsTable("tbl11453") + checkAnswer( + sqlContext.read.table("tbl11453").select("i", "j").orderBy("i"), + Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Row("4", "40") :: Nil) + } + } } From 9e48cdfbdecc9554a425ba35c0252910fd1e8faa Mon Sep 17 00:00:00 2001 From: Charles Yeh Date: Mon, 9 Nov 2015 13:22:05 +0100 Subject: [PATCH 40/88] [SPARK-11218][CORE] show help messages for start-slave and start-master Addressing https://issues.apache.org/jira/browse/SPARK-11218, mostly copied start-thriftserver.sh. ``` charlesyeh-mbp:spark charlesyeh$ ./sbin/start-master.sh --help Usage: Master [options] Options: -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) -h HOST, --host HOST Hostname to listen on -p PORT, --port PORT Port to listen on (default: 7077) --webui-port PORT Port for web UI (default: 8080) --properties-file FILE Path to a custom Spark properties file. Default is conf/spark-defaults.conf. ``` ``` charlesyeh-mbp:spark charlesyeh$ ./sbin/start-slave.sh Usage: Worker [options] Master must be a URL of the form spark://hostname:port Options: -c CORES, --cores CORES Number of cores to use -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G) -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work) -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h) -h HOST, --host HOST Hostname to listen on -p PORT, --port PORT Port to listen on (default: random) --webui-port PORT Port for web UI (default: 8081) --properties-file FILE Path to a custom Spark properties file. Default is conf/spark-defaults.conf. ``` Author: Charles Yeh Closes #9432 from CharlesYeh/helpmsg. --- sbin/start-master.sh | 24 +++++++++++++++++++----- sbin/start-slave.sh | 24 +++++++++++++++--------- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/sbin/start-master.sh b/sbin/start-master.sh index c20e19a8412df..9f2e14dff609f 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -23,6 +23,20 @@ if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. +CLASS="org.apache.spark.deploy.master.Master" + +if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-master.sh [options]" + pattern="Usage:" + pattern+="\|Using Spark's default log4j profile:" + pattern+="\|Registered signal handlers for" + + "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + exit 1 +fi + ORIGINAL_ARGS="$@" START_TACHYON=false @@ -30,7 +44,7 @@ START_TACHYON=false while (( "$#" )); do case $1 in --with-tachyon) - if [ ! -e "$sbin"/../tachyon/bin/tachyon ]; then + if [ ! -e "${SPARK_HOME}"/tachyon/bin/tachyon ]; then echo "Error: --with-tachyon specified, but tachyon not found." exit -1 fi @@ -56,12 +70,12 @@ if [ "$SPARK_MASTER_WEBUI_PORT" = "" ]; then SPARK_MASTER_WEBUI_PORT=8080 fi -"${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 \ +"${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS 1 \ --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT \ $ORIGINAL_ARGS if [ "$START_TACHYON" == "true" ]; then - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon format -s - "${SPARK_HOME}/sbin"/../tachyon/bin/tachyon-start.sh master + "${SPARK_HOME}"/tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP + "${SPARK_HOME}"/tachyon/bin/tachyon format -s + "${SPARK_HOME}"/tachyon/bin/tachyon-start.sh master fi diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index 21455648d1c6d..8c268b8859155 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -31,18 +31,24 @@ # worker. Subsequent workers will increment this # number. Default is 8081. -usage="Usage: start-slave.sh where is like spark://localhost:7077" - -if [ $# -lt 1 ]; then - echo $usage - echo Called as start-slave.sh $* - exit 1 -fi - if [ -z "${SPARK_HOME}" ]; then export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" fi +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. +CLASS="org.apache.spark.deploy.worker.Worker" + +if [[ $# -lt 1 ]] || [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then + echo "Usage: ./sbin/start-slave.sh [options] " + pattern="Usage:" + pattern+="\|Using Spark's default log4j profile:" + pattern+="\|Registered signal handlers for" + + "${SPARK_HOME}"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + exit 1 +fi + . "${SPARK_HOME}/sbin/spark-config.sh" . "${SPARK_HOME}/bin/load-spark-env.sh" @@ -72,7 +78,7 @@ function start_instance { fi WEBUI_PORT=$(( $SPARK_WORKER_WEBUI_PORT + $WORKER_NUM - 1 )) - "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker $WORKER_NUM \ + "${SPARK_HOME}/sbin"/spark-daemon.sh start $CLASS $WORKER_NUM \ --webui-port "$WEBUI_PORT" $PORT_FLAG $PORT_NUM $MASTER "$@" } From b541b31630b1b85b48d6096079d073ccf46a62e8 Mon Sep 17 00:00:00 2001 From: Rohit Agarwal Date: Mon, 9 Nov 2015 13:28:00 +0100 Subject: [PATCH 41/88] [DOC][MINOR][SQL] Fix internal link It doesn't show up as a hyperlink currently. It will show up as a hyperlink after this change. Author: Rohit Agarwal Closes #9544 from mindprince/patch-2. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 52e03b951f966..ccd26904329d3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -2287,7 +2287,7 @@ Several caching related features are not supported yet: Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Hive SerDes and UDFs are based on Hive 1.2.1, and Spark SQL can be connected to different versions of Hive Metastore -(from 0.12.0 to 1.2.1. Also see http://spark.apache.org/docs/latest/sql-programming-guide.html#interacting-with-different-versions-of-hive-metastore). +(from 0.12.0 to 1.2.1. Also see [Interacting with Different Versions of Hive Metastore] (#interacting-with-different-versions-of-hive-metastore)). #### Deploying in Existing Hive Warehouses From 8c0e1b50e960d3e8e51d0618c462eed2bb4936f0 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 9 Nov 2015 08:56:22 -0800 Subject: [PATCH 42/88] [SPARK-11494][ML][R] Expose R-like summary statistics in SparkR::glm for linear regression Expose R-like summary statistics in SparkR::glm for linear regression, the output of ```summary``` like ```Java $DevianceResiduals Min Max -0.9509607 0.7291832 $Coefficients Estimate Std. Error t value Pr(>|t|) (Intercept) 1.6765 0.2353597 7.123139 4.456124e-11 Sepal_Length 0.3498801 0.04630128 7.556598 4.187317e-12 Species_versicolor -0.9833885 0.07207471 -13.64402 0 Species_virginica -1.00751 0.09330565 -10.79796 0 ``` Author: Yanbo Liang Closes #9561 from yanboliang/spark-11494. --- R/pkg/R/mllib.R | 22 ++++++-- R/pkg/inst/tests/test_mllib.R | 31 +++++++++--- .../apache/spark/ml/r/SparkRWrappers.scala | 50 +++++++++++++++++-- 3 files changed, 88 insertions(+), 15 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b0d73dd93a79d..7ff859741b4a0 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -91,12 +91,26 @@ setMethod("predict", signature(object = "PipelineModel"), #'} setMethod("summary", signature(x = "PipelineModel"), function(x, ...) { + modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelName", x@model) features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "getModelFeatures", x@model) coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "getModelCoefficients", x@model) - coefficients <- as.matrix(unlist(coefficients)) - colnames(coefficients) <- c("Estimate") - rownames(coefficients) <- unlist(features) - return(list(coefficients = coefficients)) + if (modelName == "LinearRegressionModel") { + devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelDevianceResiduals", x@model) + devianceResiduals <- matrix(devianceResiduals, nrow = 1) + colnames(devianceResiduals) <- c("Min", "Max") + rownames(devianceResiduals) <- rep("", times = 1) + coefficients <- matrix(coefficients, ncol = 4) + colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") + rownames(coefficients) <- unlist(features) + return(list(DevianceResiduals = devianceResiduals, Coefficients = coefficients)) + } else { + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + } }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 4761e285a2479..2606407bdcb44 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -71,12 +71,23 @@ test_that("feature interaction vs native glm", { test_that("summary coefficients match with native glm", { training <- createDataFrame(sqlContext, iris) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "l-bfgs")) - coefs <- as.vector(stats$coefficients) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal")) + coefs <- unlist(stats$Coefficients) + devianceResiduals <- unlist(stats$DevianceResiduals) + rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) - expect_true(all(abs(rCoefs - coefs) < 1e-6)) + rStdError <- c(0.23536, 0.04630, 0.07207, 0.09331) + rTValue <- c(7.123, 7.557, -13.644, -10.798) + rPValue <- c(0.0, 0.0, 0.0, 0.0) + rDevianceResiduals <- c(-0.95096, 0.72918) + + expect_true(all(abs(rCoefs - coefs[1:4]) < 1e-6)) + expect_true(all(abs(rStdError - coefs[5:8]) < 1e-5)) + expect_true(all(abs(rTValue - coefs[9:12]) < 1e-3)) + expect_true(all(abs(rPValue - coefs[13:16]) < 1e-6)) + expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5)) expect_true(all( - as.character(stats$features) == + rownames(stats$Coefficients) == c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) }) @@ -85,14 +96,20 @@ test_that("summary coefficients match with native glm of family 'binomial'", { training <- filter(df, df$Species != "setosa") stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial")) - coefs <- as.vector(stats$coefficients) + coefs <- as.vector(stats$Coefficients) rTraining <- iris[iris$Species %in% c("versicolor","virginica"),] rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, family = binomial(link = "logit")))) + rStdError <- c(3.0974, 0.5169, 0.8628) + rTValue <- c(-4.212, 3.680, 0.469) + rPValue <- c(0.000, 0.000, 0.639) - expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all(abs(rCoefs - coefs[1:3]) < 1e-4)) + expect_true(all(abs(rStdError - coefs[4:6]) < 1e-4)) + expect_true(all(abs(rTValue - coefs[7:9]) < 1e-3)) + expect_true(all(abs(rPValue - coefs[10:12]) < 1e-3)) expect_true(all( - as.character(stats$features) == + rownames(stats$Coefficients) == c("(Intercept)", "Sepal_Length", "Sepal_Width"))) }) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 5be2f86936211..4d82b90bfdf20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -52,11 +52,36 @@ private[r] object SparkRWrappers { } def getModelCoefficients(model: PipelineModel): Array[Double] = { + model.stages.last match { + case m: LinearRegressionModel => { + val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++ + m.summary.coefficientStandardErrors.dropRight(1) + val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1) + val pValuesR = Array(m.summary.pValues.last) ++ m.summary.pValues.dropRight(1) + if (m.getFitIntercept) { + Array(m.intercept) ++ m.coefficients.toArray ++ coefficientStandardErrorsR ++ + tValuesR ++ pValuesR + } else { + m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR + } + } + case m: LogisticRegressionModel => { + if (m.getFitIntercept) { + Array(m.intercept) ++ m.coefficients.toArray + } else { + m.coefficients.toArray + } + } + } + } + + def getModelDevianceResiduals(model: PipelineModel): Array[Double] = { model.stages.last match { case m: LinearRegressionModel => - Array(m.intercept) ++ m.coefficients.toArray + m.summary.devianceResiduals case m: LogisticRegressionModel => - Array(m.intercept) ++ m.coefficients.toArray + throw new UnsupportedOperationException( + "No deviance residuals available for LogisticRegressionModel") } } @@ -65,11 +90,28 @@ private[r] object SparkRWrappers { case m: LinearRegressionModel => val attrs = AttributeGroup.fromStructField( m.summary.predictions.schema(m.summary.featuresCol)) - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + if (m.getFitIntercept) { + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + } else { + attrs.attributes.get.map(_.name.get) + } case m: LogisticRegressionModel => val attrs = AttributeGroup.fromStructField( m.summary.predictions.schema(m.summary.featuresCol)) - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + if (m.getFitIntercept) { + Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) + } else { + attrs.attributes.get.map(_.name.get) + } + } + } + + def getModelName(model: PipelineModel): String = { + model.stages.last match { + case m: LinearRegressionModel => + "LinearRegressionModel" + case m: LogisticRegressionModel => + "LogisticRegressionModel" } } } From d50a66cc04bfa1c483f04daffe465322316c745e Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 9 Nov 2015 08:57:29 -0800 Subject: [PATCH 43/88] [SPARK-10689][ML][DOC] User guide and example code for AFTSurvivalRegression Add user guide and example code for ```AFTSurvivalRegression```. Author: Yanbo Liang Closes #9491 from yanboliang/spark-10689. --- docs/ml-guide.md | 1 + docs/ml-survival-regression.md | 96 +++++++++++++++++++ .../ml/JavaAFTSurvivalRegressionExample.java | 71 ++++++++++++++ .../main/python/ml/aft_survival_regression.py | 51 ++++++++++ .../ml/AFTSurvivalRegressionExample.scala | 62 ++++++++++++ 5 files changed, 281 insertions(+) create mode 100644 docs/ml-survival-regression.md create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java create mode 100644 examples/src/main/python/ml/aft_survival_regression.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala diff --git a/docs/ml-guide.md b/docs/ml-guide.md index fd3a6167bc65e..c293e71d2870e 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -44,6 +44,7 @@ provide class probabilities, and linear models provide model summaries. * [Ensembles](ml-ensembles.html) * [Linear methods with elastic net regularization](ml-linear-methods.html) * [Multilayer perceptron classifier](ml-ann.html) +* [Survival Regression](ml-survival-regression.html) # Main concepts in Pipelines diff --git a/docs/ml-survival-regression.md b/docs/ml-survival-regression.md new file mode 100644 index 0000000000000..ab275213b9a84 --- /dev/null +++ b/docs/ml-survival-regression.md @@ -0,0 +1,96 @@ +--- +layout: global +title: Survival Regression - ML +displayTitle: ML - Survival Regression +--- + + +`\[ +\newcommand{\R}{\mathbb{R}} +\newcommand{\E}{\mathbb{E}} +\newcommand{\x}{\mathbf{x}} +\newcommand{\y}{\mathbf{y}} +\newcommand{\wv}{\mathbf{w}} +\newcommand{\av}{\mathbf{\alpha}} +\newcommand{\bv}{\mathbf{b}} +\newcommand{\N}{\mathbb{N}} +\newcommand{\id}{\mathbf{I}} +\newcommand{\ind}{\mathbf{1}} +\newcommand{\0}{\mathbf{0}} +\newcommand{\unit}{\mathbf{e}} +\newcommand{\one}{\mathbf{1}} +\newcommand{\zero}{\mathbf{0}} +\]` + + +In `spark.ml`, we implement the [Accelerated failure time (AFT)](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) +model which is a parametric survival regression model for censored data. +It describes a model for the log of survival time, so it's often called +log-linear model for survival analysis. Different from +[Proportional hazards](https://en.wikipedia.org/wiki/Proportional_hazards_model) model +designed for the same purpose, the AFT model is more easily to parallelize +because each instance contribute to the objective function independently. + +Given the values of the covariates $x^{'}$, for random lifetime $t_{i}$ of +subjects i = 1, ..., n, with possible right-censoring, +the likelihood function under the AFT model is given as: +`\[ +L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}} +\]` +Where $\delta_{i}$ is the indicator of the event has occurred i.e. uncensored or not. +Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function +assumes the form: +`\[ +\iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}] +\]` +Where $S_{0}(\epsilon_{i})$ is the baseline survivor function, +and $f_{0}(\epsilon_{i})$ is corresponding density function. + +The most commonly used AFT model is based on the Weibull distribution of the survival time. +The Weibull distribution for lifetime corresponding to extreme value distribution for +log of the lifetime, and the $S_{0}(\epsilon)$ function is: +`\[ +S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}}) +\]` +the $f_{0}(\epsilon_{i})$ function is: +`\[ +f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}}) +\]` +The log-likelihood function for AFT model with Weibull distribution of lifetime is: +`\[ +\iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}] +\]` +Due to minimizing the negative log-likelihood equivalent to maximum a posteriori probability, +the loss function we use to optimize is $-\iota(\beta,\sigma)$. +The gradient functions for $\beta$ and $\log\sigma$ respectively are: +`\[ +\frac{\partial (-\iota)}{\partial \beta}=\sum_{1=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma} +\]` +`\[ +\frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}] +\]` + +The AFT model can be formulated as a convex optimization problem, +i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$ +that depends coefficients vector $\beta$ and the log of scale parameter $\log\sigma$. +The optimization algorithm underlying the implementation is L-BFGS. +The implementation matches the result from R's survival function +[survreg](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) + +## Example: + +
+ +
+{% include_example scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala %} +
+ +
+{% include_example java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java %} +
+ +
+{% include_example python/ml/aft_survival_regression.py %} +
+ +
\ No newline at end of file diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java new file mode 100644 index 0000000000000..69a174562fcf5 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.regression.AFTSurvivalRegression; +import org.apache.spark.ml.regression.AFTSurvivalRegressionModel; +import org.apache.spark.mllib.linalg.*; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.*; +// $example off$ + +public class JavaAFTSurvivalRegressionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaAFTSurvivalRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(1.218, 1.0, Vectors.dense(1.560, -0.605)), + RowFactory.create(2.949, 0.0, Vectors.dense(0.346, 2.158)), + RowFactory.create(3.627, 0.0, Vectors.dense(1.380, 0.231)), + RowFactory.create(0.273, 1.0, Vectors.dense(0.520, 1.151)), + RowFactory.create(4.199, 0.0, Vectors.dense(0.795, -0.226)) + ); + StructType schema = new StructType(new StructField[]{ + new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("censor", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("features", new VectorUDT(), false, Metadata.empty()) + }); + DataFrame training = jsql.createDataFrame(data, schema); + double[] quantileProbabilities = new double[]{0.3, 0.6}; + AFTSurvivalRegression aft = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles"); + + AFTSurvivalRegressionModel model = aft.fit(training); + + // Print the coefficients, intercept and scale parameter for AFT survival regression + System.out.println("Coefficients: " + model.coefficients() + " Intercept: " + + model.intercept() + " Scale: " + model.scale()); + model.transform(training).show(false); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/python/ml/aft_survival_regression.py b/examples/src/main/python/ml/aft_survival_regression.py new file mode 100644 index 0000000000000..0ee01fd8258df --- /dev/null +++ b/examples/src/main/python/ml/aft_survival_regression.py @@ -0,0 +1,51 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.regression import AFTSurvivalRegression +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="AFTSurvivalRegressionExample") + sqlContext = SQLContext(sc) + + # $example on$ + training = sqlContext.createDataFrame([ + (1.218, 1.0, Vectors.dense(1.560, -0.605)), + (2.949, 0.0, Vectors.dense(0.346, 2.158)), + (3.627, 0.0, Vectors.dense(1.380, 0.231)), + (0.273, 1.0, Vectors.dense(0.520, 1.151)), + (4.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", "features"]) + quantileProbabilities = [0.3, 0.6] + aft = AFTSurvivalRegression(quantileProbabilities=quantileProbabilities, + quantilesCol="quantiles") + + model = aft.fit(training) + + # Print the coefficients, intercept and scale parameter for AFT survival regression + print("Coefficients: " + str(model.coefficients)) + print("Intercept: " + str(model.intercept)) + print("Scale: " + str(model.scale)) + model.transform(training).show(truncate=False) + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala new file mode 100644 index 0000000000000..5da285e83681f --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.ml.regression.AFTSurvivalRegression +import org.apache.spark.mllib.linalg.Vectors +// $example off$ + +/** + * An example for AFTSurvivalRegression. + */ +object AFTSurvivalRegressionExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("AFTSurvivalRegressionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + + // $example on$ + val training = sqlContext.createDataFrame(Seq( + (1.218, 1.0, Vectors.dense(1.560, -0.605)), + (2.949, 0.0, Vectors.dense(0.346, 2.158)), + (3.627, 0.0, Vectors.dense(1.380, 0.231)), + (0.273, 1.0, Vectors.dense(0.520, 1.151)), + (4.199, 0.0, Vectors.dense(0.795, -0.226)) + )).toDF("label", "censor", "features") + val quantileProbabilities = Array(0.3, 0.6) + val aft = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") + + val model = aft.fit(training) + + // Print the coefficients, intercept and scale parameter for AFT survival regression + println(s"Coefficients: ${model.coefficients} Intercept: " + + s"${model.intercept} Scale: ${model.scale}") + model.transform(training).show(false) + // $example off$ + + sc.stop() + } +} +// scalastyle:off println From 9b88e1dcad6b5b14a22cf64a1055ad9870507b5a Mon Sep 17 00:00:00 2001 From: fazlan-nazeem Date: Mon, 9 Nov 2015 08:58:55 -0800 Subject: [PATCH 44/88] [SPARK-11582][MLLIB] specifying pmml version attribute =4.2 in the root node of pmml model The current pmml models generated do not specify the pmml version in its root node. This is a problem when using this pmml model in other tools because they expect the version attribute to be set explicitly. This fix adds the pmml version attribute to the generated pmml models and specifies its value as 4.2. Author: fazlan-nazeem Closes #9558 from fazlan-nazeem/master. --- .../org/apache/spark/mllib/pmml/export/PMMLModelExport.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala index c5fdecd3ca17f..9267e6dbdb857 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExport.scala @@ -32,6 +32,7 @@ private[mllib] trait PMMLModelExport { @BeanProperty val pmml: PMML = new PMML + pmml.setVersion("4.2") setHeader(pmml) private def setHeader(pmml: PMML): Unit = { From 08a7a836c393d6a62b9b216eeb01fad0b90b6c52 Mon Sep 17 00:00:00 2001 From: Charles Yeh Date: Mon, 9 Nov 2015 11:59:32 -0600 Subject: [PATCH 45/88] [SPARK-10565][CORE] add missing web UI stats to /api/v1/applications JSON I looked at the other endpoints, and they don't seem to be missing any fields. Added fields: ![image](https://cloud.githubusercontent.com/assets/613879/10948801/58159982-82e4-11e5-86dc-62da201af910.png) Author: Charles Yeh Closes #9472 from CharlesYeh/api_vars. --- .../spark/deploy/master/ui/MasterWebUI.scala | 7 +- .../api/v1/ApplicationListResource.scala | 8 ++ .../org/apache/spark/status/api/v1/api.scala | 4 + .../scala/org/apache/spark/ui/SparkUI.scala | 4 + .../deploy/master/ui/MasterWebUISuite.scala | 90 +++++++++++++++++++ project/MimaExcludes.scala | 3 + 6 files changed, 114 insertions(+), 2 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 6174fc11f83d8..e41554a5a6d26 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -28,14 +28,17 @@ import org.apache.spark.ui.JettyUtils._ * Web UI server for the standalone master. */ private[master] -class MasterWebUI(val master: Master, requestedPort: Int) +class MasterWebUI( + val master: Master, + requestedPort: Int, + customMasterPage: Option[MasterPage] = None) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging with UIRoot { val masterEndpointRef = master.self val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) - val masterPage = new MasterPage(this) + val masterPage = customMasterPage.getOrElse(new MasterPage(this)) initialize() diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala index 17b521f3e1d41..0fc0fb59d861f 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApplicationListResource.scala @@ -62,6 +62,10 @@ private[spark] object ApplicationsListResource { new ApplicationInfo( id = app.id, name = app.name, + coresGranted = None, + maxCores = None, + coresPerExecutor = None, + memoryPerExecutorMB = None, attempts = app.attempts.map { internalAttemptInfo => new ApplicationAttemptInfo( attemptId = internalAttemptInfo.attemptId, @@ -81,6 +85,10 @@ private[spark] object ApplicationsListResource { new ApplicationInfo( id = internal.id, name = internal.desc.name, + coresGranted = Some(internal.coresGranted), + maxCores = internal.desc.maxCores, + coresPerExecutor = internal.desc.coresPerExecutor, + memoryPerExecutorMB = Some(internal.desc.memoryPerExecutorMB), attempts = Seq(new ApplicationAttemptInfo( attemptId = None, startTime = new Date(internal.startTime), diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 2bec64f2ef02b..baddfc50c1a40 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -25,6 +25,10 @@ import org.apache.spark.JobExecutionStatus class ApplicationInfo private[spark]( val id: String, val name: String, + val coresGranted: Option[Int], + val maxCores: Option[Int], + val coresPerExecutor: Option[Int], + val memoryPerExecutorMB: Option[Int], val attempts: Seq[ApplicationAttemptInfo]) class ApplicationAttemptInfo private[spark]( diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 99085ada9f0af..4608bce202ec8 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -102,6 +102,10 @@ private[spark] class SparkUI private ( Iterator(new ApplicationInfo( id = appId, name = appName, + coresGranted = None, + maxCores = None, + coresPerExecutor = None, + memoryPerExecutorMB = None, attempts = Seq(new ApplicationAttemptInfo( attemptId = None, startTime = new Date(startTime), diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala new file mode 100644 index 0000000000000..fba835f054f8a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master.ui + +import java.util.Date + +import scala.io.Source +import scala.language.postfixOps + +import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonAST.{JNothing, JString, JInt} +import org.mockito.Mockito.{mock, when} +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SparkConf, SecurityManager, SparkFunSuite} +import org.apache.spark.deploy.DeployMessages.MasterStateResponse +import org.apache.spark.deploy.DeployTestUtils._ +import org.apache.spark.deploy.master._ +import org.apache.spark.rpc.RpcEnv + + +class MasterWebUISuite extends SparkFunSuite with BeforeAndAfter { + + val masterPage = mock(classOf[MasterPage]) + val master = { + val conf = new SparkConf + val securityMgr = new SecurityManager(conf) + val rpcEnv = RpcEnv.create(Master.SYSTEM_NAME, "localhost", 0, conf, securityMgr) + val master = new Master(rpcEnv, rpcEnv.address, 0, securityMgr, conf) + master + } + val masterWebUI = new MasterWebUI(master, 0, customMasterPage = Some(masterPage)) + + before { + masterWebUI.bind() + } + + after { + masterWebUI.stop() + } + + test("list applications") { + val worker = createWorkerInfo() + val appDesc = createAppDesc() + // use new start date so it isn't filtered by UI + val activeApp = new ApplicationInfo( + new Date().getTime, "id", appDesc, new Date(), null, Int.MaxValue) + activeApp.addExecutor(worker, 2) + + val workers = Array[WorkerInfo](worker) + val activeApps = Array(activeApp) + val completedApps = Array[ApplicationInfo]() + val activeDrivers = Array[DriverInfo]() + val completedDrivers = Array[DriverInfo]() + val stateResponse = new MasterStateResponse( + "host", 8080, None, workers, activeApps, completedApps, + activeDrivers, completedDrivers, RecoveryState.ALIVE) + + when(masterPage.getMasterState).thenReturn(stateResponse) + + val resultJson = Source.fromURL( + s"http://localhost:${masterWebUI.boundPort}/api/v1/applications") + .mkString + val parsedJson = parse(resultJson) + val firstApp = parsedJson(0) + + assert(firstApp \ "id" === JString(activeApp.id)) + assert(firstApp \ "name" === JString(activeApp.desc.name)) + assert(firstApp \ "coresGranted" === JInt(2)) + assert(firstApp \ "maxCores" === JInt(4)) + assert(firstApp \ "memoryPerExecutorMB" === JInt(1234)) + assert(firstApp \ "coresPerExecutor" === JNothing) + } + +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dacef911e397e..50220790d1f84 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -134,6 +134,9 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.toString"), ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.jdbc.PostgresDialect.hashCode"), ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.jdbc.NoopDialect$") + ) ++ Seq ( + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.status.api.v1.ApplicationInfo.this") ) case v if v.startsWith("1.5") => Seq( From 404a28f4edd09cf17361dcbd770e4cafde51bf6d Mon Sep 17 00:00:00 2001 From: tedyu Date: Mon, 9 Nov 2015 10:07:58 -0800 Subject: [PATCH 46/88] [SPARK-11112] Fix Scala 2.11 compilation error in RDDInfo.scala As shown in https://amplab.cs.berkeley.edu/jenkins/view/Spark-QA-Compile/job/Spark-Master-Scala211-Compile/1946/console , compilation fails with: ``` [error] /home/jenkins/workspace/Spark-Master-Scala211-Compile/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala:25: in class RDDInfo, multiple overloaded alternatives of constructor RDDInfo define default arguments. [error] class RDDInfo( [error] ``` This PR tries to fix the compilation error Author: tedyu Closes #9538 from tedyu/master. --- .../scala/org/apache/spark/storage/RDDInfo.scala | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 3fa209b924170..87c1b981e7e13 100644 --- a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -28,20 +28,10 @@ class RDDInfo( val numPartitions: Int, var storageLevel: StorageLevel, val parentIds: Seq[Int], - val callSite: CallSite, + val callSite: CallSite = CallSite.empty, val scope: Option[RDDOperationScope] = None) extends Ordered[RDDInfo] { - def this( - id: Int, - name: String, - numPartitions: Int, - storageLevel: StorageLevel, - parentIds: Seq[Int], - scope: Option[RDDOperationScope] = None) { - this(id, name, numPartitions, storageLevel, parentIds, CallSite.empty, scope) - } - var numCachedPartitions = 0 var memSize = 0L var diskSize = 0L From cd174882a5a211298d6e173fe989d567d08ebc0d Mon Sep 17 00:00:00 2001 From: felixcheung Date: Mon, 9 Nov 2015 10:26:09 -0800 Subject: [PATCH 47/88] [SPARK-9865][SPARKR] Flaky SparkR test: test_sparkSQL.R: sample on a DataFrame Make sample test less flaky by setting the seed Tested with ``` repeat { if (count(sample(df, FALSE, 0.1)) == 3) { break } } ``` Author: felixcheung Closes #9549 from felixcheung/rsample. --- R/pkg/inst/tests/test_sparkSQL.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 92cff1fba7193..fbdb9a8f1ef6b 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -647,11 +647,11 @@ test_that("sample on a DataFrame", { sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) expect_is(sampled, "DataFrame") - sampled2 <- sample(df, FALSE, 0.1) + sampled2 <- sample(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled2) < 3) # Also test sample_frac - sampled3 <- sample_frac(df, FALSE, 0.1) + sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) }) From 874cd66d4b6d156d0ef112a3d0f3bc5683c6a0ec Mon Sep 17 00:00:00 2001 From: chriskang90 Date: Mon, 9 Nov 2015 19:39:22 +0100 Subject: [PATCH 48/88] [DOCS] Fix typo for Python section on unifying Kafka streams 1) kafkaStreams is a list. The list should be unpacked when passing it into the streaming context union method, which accepts a variable number of streams. 2) print() should be pprint() for pyspark. This contribution is my original work, and I license the work to the project under the project's open source license. Author: chriskang90 Closes #9545 from c-kang/streaming_python_typo. --- docs/streaming-programming-guide.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index c751dbb41785a..e9a27f446a898 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -1948,8 +1948,8 @@ unifiedStream.print(); {% highlight python %} numStreams = 5 kafkaStreams = [KafkaUtils.createStream(...) for _ in range (numStreams)] -unifiedStream = streamingContext.union(kafkaStreams) -unifiedStream.print() +unifiedStream = streamingContext.union(*kafkaStreams) +unifiedStream.pprint() {% endhighlight %}
From 860ea0d386b5fbbe26bf2954f402a9a73ad37edc Mon Sep 17 00:00:00 2001 From: Bharat Lal Date: Mon, 9 Nov 2015 11:33:01 -0800 Subject: [PATCH 49/88] [SPARK-11581][DOCS] Example mllib code in documentation incorrectly computes MSE Author: Bharat Lal Closes #9560 from bharatl/SPARK-11581. --- docs/mllib-decision-tree.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index f31c4f88936bd..b5b454bc69245 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -439,7 +439,7 @@ Double testMSE = public Double call(Double a, Double b) { return a + b; } - }) / data.count(); + }) / testData.count(); System.out.println("Test Mean Squared Error: " + testMSE); System.out.println("Learned regression tree model:\n" + model.toDebugString()); From 88a3fdcc783f880a8d01c7e194ec42fc114bdf8a Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Mon, 9 Nov 2015 13:16:04 -0800 Subject: [PATCH 50/88] [SPARK-10280][MLLIB][PYSPARK][DOCS] Add @since annotation to pyspark.ml.classification Author: Yu ISHIKAWA Closes #8690 from yu-iskw/SPARK-10280. --- python/pyspark/ml/classification.py | 56 +++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 2e468f67b8987..603f2c7f798dc 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -67,6 +67,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. + + .. versionadded:: 1.3.0 """ # a placeholder to make it appear in the generated doc @@ -99,6 +101,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._checkThresholdConsistency() @keyword_only + @since("1.3.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, threshold=0.5, thresholds=None, probabilityCol="probability", @@ -119,6 +122,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return LogisticRegressionModel(java_model) + @since("1.4.0") def setThreshold(self, value): """ Sets the value of :py:attr:`threshold`. @@ -129,6 +133,7 @@ def setThreshold(self, value): del self._paramMap[self.thresholds] return self + @since("1.4.0") def getThreshold(self): """ Gets the value of threshold or its default value. @@ -144,6 +149,7 @@ def getThreshold(self): else: return self.getOrDefault(self.threshold) + @since("1.5.0") def setThresholds(self, value): """ Sets the value of :py:attr:`thresholds`. @@ -154,6 +160,7 @@ def setThresholds(self, value): del self._paramMap[self.threshold] return self + @since("1.5.0") def getThresholds(self): """ If :py:attr:`thresholds` is set, return its value. @@ -185,9 +192,12 @@ def _checkThresholdConsistency(self): class LogisticRegressionModel(JavaModel): """ Model fitted by LogisticRegression. + + .. versionadded:: 1.3.0 """ @property + @since("1.4.0") def weights(self): """ Model weights. @@ -205,6 +215,7 @@ def coefficients(self): return self._call_java("coefficients") @property + @since("1.4.0") def intercept(self): """ Model intercept. @@ -215,6 +226,8 @@ def intercept(self): class TreeClassifierParams(object): """ Private class to track supported impurity measures. + + .. versionadded:: 1.4.0 """ supportedImpurities = ["entropy", "gini"] @@ -231,6 +244,7 @@ def __init__(self): "gain calculation (case-insensitive). Supported options: " + ", ".join(self.supportedImpurities)) + @since("1.6.0") def setImpurity(self, value): """ Sets the value of :py:attr:`impurity`. @@ -238,6 +252,7 @@ def setImpurity(self, value): self._paramMap[self.impurity] = value return self + @since("1.6.0") def getImpurity(self): """ Gets the value of impurity or its default value. @@ -248,6 +263,8 @@ def getImpurity(self): class GBTParams(TreeEnsembleParams): """ Private class to track supported GBT params. + + .. versionadded:: 1.4.0 """ supportedLossTypes = ["logistic"] @@ -287,6 +304,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ @keyword_only @@ -310,6 +329,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, @@ -333,6 +353,8 @@ def _create_model(self, java_model): class DecisionTreeClassificationModel(DecisionTreeModel): """ Model fitted by DecisionTreeClassifier. + + .. versionadded:: 1.4.0 """ @@ -371,6 +393,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ @keyword_only @@ -396,6 +420,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, @@ -419,6 +444,8 @@ def _create_model(self, java_model): class RandomForestClassificationModel(TreeEnsembleModels): """ Model fitted by RandomForestClassifier. + + .. versionadded:: 1.4.0 """ @@ -450,6 +477,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.4.0 """ # a placeholder to make it appear in the generated doc @@ -482,6 +511,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.4.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, @@ -499,6 +529,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return GBTClassificationModel(java_model) + @since("1.4.0") def setLossType(self, value): """ Sets the value of :py:attr:`lossType`. @@ -506,6 +537,7 @@ def setLossType(self, value): self._paramMap[self.lossType] = value return self + @since("1.4.0") def getLossType(self): """ Gets the value of lossType or its default value. @@ -516,6 +548,8 @@ def getLossType(self): class GBTClassificationModel(TreeEnsembleModels): """ Model fitted by GBTClassifier. + + .. versionadded:: 1.4.0 """ @@ -555,6 +589,8 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() >>> model.transform(test1).head().prediction 1.0 + + .. versionadded:: 1.5.0 """ # a placeholder to make it appear in the generated doc @@ -587,6 +623,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.5.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", probabilityCol="probability", rawPredictionCol="rawPrediction", smoothing=1.0, modelType="multinomial"): @@ -602,6 +639,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return NaiveBayesModel(java_model) + @since("1.5.0") def setSmoothing(self, value): """ Sets the value of :py:attr:`smoothing`. @@ -609,12 +647,14 @@ def setSmoothing(self, value): self._paramMap[self.smoothing] = value return self + @since("1.5.0") def getSmoothing(self): """ Gets the value of smoothing or its default value. """ return self.getOrDefault(self.smoothing) + @since("1.5.0") def setModelType(self, value): """ Sets the value of :py:attr:`modelType`. @@ -622,6 +662,7 @@ def setModelType(self, value): self._paramMap[self.modelType] = value return self + @since("1.5.0") def getModelType(self): """ Gets the value of modelType or its default value. @@ -632,9 +673,12 @@ def getModelType(self): class NaiveBayesModel(JavaModel): """ Model fitted by NaiveBayes. + + .. versionadded:: 1.5.0 """ @property + @since("1.5.0") def pi(self): """ log of class priors. @@ -642,6 +686,7 @@ def pi(self): return self._call_java("pi") @property + @since("1.5.0") def theta(self): """ log of class conditional probabilities. @@ -681,6 +726,8 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, |[0.0,0.0]| 0.0| +---------+----------+ ... + + .. versionadded:: 1.6.0 """ # a placeholder to make it appear in the generated doc @@ -715,6 +762,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self.setParams(**kwargs) @keyword_only + @since("1.6.0") def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128): """ @@ -731,6 +779,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre def _create_model(self, java_model): return MultilayerPerceptronClassificationModel(java_model) + @since("1.6.0") def setLayers(self, value): """ Sets the value of :py:attr:`layers`. @@ -738,12 +787,14 @@ def setLayers(self, value): self._paramMap[self.layers] = value return self + @since("1.6.0") def getLayers(self): """ Gets the value of layers or its default value. """ return self.getOrDefault(self.layers) + @since("1.6.0") def setBlockSize(self, value): """ Sets the value of :py:attr:`blockSize`. @@ -751,6 +802,7 @@ def setBlockSize(self, value): self._paramMap[self.blockSize] = value return self + @since("1.6.0") def getBlockSize(self): """ Gets the value of blockSize or its default value. @@ -761,9 +813,12 @@ def getBlockSize(self): class MultilayerPerceptronClassificationModel(JavaModel): """ Model fitted by MultilayerPerceptronClassifier. + + .. versionadded:: 1.6.0 """ @property + @since("1.6.0") def layers(self): """ array of layer sizes including input and output layers. @@ -771,6 +826,7 @@ def layers(self): return self._call_java("javaLayers") @property + @since("1.6.0") def weights(self): """ vector of initial weights for the model that consists of the weights of layers. From 5039a49b636325f321daa089971107003fae9d4b Mon Sep 17 00:00:00 2001 From: Felix Bechstein Date: Mon, 9 Nov 2015 13:36:14 -0800 Subject: [PATCH 51/88] [SPARK-10471][CORE][MESOS] prevent getting offers for unmet constraints this change rejects offers for slaves with unmet constraints for 120s to mitigate offer starvation. this prevents mesos to send us these offers again and again. in return, we get more offers for slaves which might meet our constraints. and it enables mesos to send the rejected offers to other frameworks. Author: Felix Bechstein Closes #8639 from felixb/decline_offers_constraint_mismatch. --- .../mesos/CoarseMesosSchedulerBackend.scala | 92 +++++++++++-------- .../cluster/mesos/MesosSchedulerBackend.scala | 48 +++++++--- .../cluster/mesos/MesosSchedulerUtils.scala | 4 + 3 files changed, 91 insertions(+), 53 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index d10a77f8e5c78..2de9b6a651692 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -101,6 +101,10 @@ private[spark] class CoarseMesosSchedulerBackend( private val slaveOfferConstraints = parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + // reject offers with mismatched constraints in seconds + private val rejectOfferDurationForUnmetConstraints = + getRejectOfferDurationForUnmetConstraints(sc) + // A client for talking to the external shuffle service, if it is a private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { if (shuffleServiceEnabled) { @@ -249,48 +253,56 @@ private[spark] class CoarseMesosSchedulerBackend( val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt val id = offer.getId.getValue - if (taskIdToSlaveId.size < executorLimit && - totalCoresAcquired < maxCores && - meetsConstraints && - mem >= calculateTotalMemory(sc) && - cpus >= 1 && - failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && - !slaveIdsWithExecutors.contains(slaveId)) { - // Launch an executor on the slave - val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) - totalCoresAcquired += cpusToUse - val taskId = newMesosTaskId() - taskIdToSlaveId.put(taskId, slaveId) - slaveIdsWithExecutors += slaveId - coresByTaskId(taskId) = cpusToUse - // Gather cpu resources from the available resources and use them in the task. - val (remainingResources, cpuResourcesToUse) = - partitionResources(offer.getResourcesList, "cpus", cpusToUse) - val (_, memResourcesToUse) = - partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) - val taskBuilder = MesosTaskInfo.newBuilder() - .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) - .setSlaveId(offer.getSlaveId) - .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) - .setName("Task " + taskId) - .addAllResources(cpuResourcesToUse.asJava) - .addAllResources(memResourcesToUse.asJava) - - sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => - MesosSchedulerBackendUtil - .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) + if (meetsConstraints) { + if (taskIdToSlaveId.size < executorLimit && + totalCoresAcquired < maxCores && + mem >= calculateTotalMemory(sc) && + cpus >= 1 && + failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && + !slaveIdsWithExecutors.contains(slaveId)) { + // Launch an executor on the slave + val cpusToUse = math.min(cpus, maxCores - totalCoresAcquired) + totalCoresAcquired += cpusToUse + val taskId = newMesosTaskId() + taskIdToSlaveId.put(taskId, slaveId) + slaveIdsWithExecutors += slaveId + coresByTaskId(taskId) = cpusToUse + // Gather cpu resources from the available resources and use them in the task. + val (remainingResources, cpuResourcesToUse) = + partitionResources(offer.getResourcesList, "cpus", cpusToUse) + val (_, memResourcesToUse) = + partitionResources(remainingResources.asJava, "mem", calculateTotalMemory(sc)) + val taskBuilder = MesosTaskInfo.newBuilder() + .setTaskId(TaskID.newBuilder().setValue(taskId.toString).build()) + .setSlaveId(offer.getSlaveId) + .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave, taskId)) + .setName("Task " + taskId) + .addAllResources(cpuResourcesToUse.asJava) + .addAllResources(memResourcesToUse.asJava) + + sc.conf.getOption("spark.mesos.executor.docker.image").foreach { image => + MesosSchedulerBackendUtil + .setupContainerBuilderDockerInfo(image, sc.conf, taskBuilder.getContainerBuilder()) + } + + // Accept the offer and launch the task + logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname + d.launchTasks( + Collections.singleton(offer.getId), + Collections.singleton(taskBuilder.build()), filters) + } else { + // Decline the offer + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + d.declineOffer(offer.getId) } - - // accept the offer and launch the task - logDebug(s"Accepting offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - slaveIdToHost(offer.getSlaveId.getValue) = offer.getHostname - d.launchTasks( - Collections.singleton(offer.getId), - Collections.singleton(taskBuilder.build()), filters) } else { - // Decline the offer - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") - d.declineOffer(offer.getId) + // This offer does not meet constraints. We don't need to see it again. + // Decline the offer for a long period of time. + logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus" + + s" for $rejectOfferDurationForUnmetConstraints seconds") + d.declineOffer(offer.getId, Filters.newBuilder() + .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build()) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index aaffac604a885..281965a5981bb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -63,6 +63,10 @@ private[spark] class MesosSchedulerBackend( private[this] val slaveOfferConstraints = parseConstraintString(sc.conf.get("spark.mesos.constraints", "")) + // reject offers with mismatched constraints in seconds + private val rejectOfferDurationForUnmetConstraints = + getRejectOfferDurationForUnmetConstraints(sc) + @volatile var appId: String = _ override def start() { @@ -212,29 +216,47 @@ private[spark] class MesosSchedulerBackend( */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { inClassLoader() { - // Fail-fast on offers we know will be rejected - val (usableOffers, unUsableOffers) = offers.asScala.partition { o => + // Fail first on offers with unmet constraints + val (offersMatchingConstraints, offersNotMatchingConstraints) = + offers.asScala.partition { o => + val offerAttributes = toAttributeMap(o.getAttributesList) + val meetsConstraints = + matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + + // add some debug messaging + if (!meetsConstraints) { + val id = o.getId.getValue + logDebug(s"Declining offer: $id with attributes: $offerAttributes") + } + + meetsConstraints + } + + // These offers do not meet constraints. We don't need to see them again. + // Decline the offer for a long period of time. + offersNotMatchingConstraints.foreach { o => + d.declineOffer(o.getId, Filters.newBuilder() + .setRefuseSeconds(rejectOfferDurationForUnmetConstraints).build()) + } + + // Of the matching constraints, see which ones give us enough memory and cores + val (usableOffers, unUsableOffers) = offersMatchingConstraints.partition { o => val mem = getResource(o.getResourcesList, "mem") val cpus = getResource(o.getResourcesList, "cpus") val slaveId = o.getSlaveId.getValue val offerAttributes = toAttributeMap(o.getAttributesList) - // check if all constraints are satisfield - // 1. Attribute constraints - // 2. Memory requirements - // 3. CPU requirements - need at least 1 for executor, 1 for task - val meetsConstraints = matchesAttributeRequirements(slaveOfferConstraints, offerAttributes) + // check offers for + // 1. Memory requirements + // 2. CPU requirements - need at least 1 for executor, 1 for task val meetsMemoryRequirements = mem >= calculateTotalMemory(sc) val meetsCPURequirements = cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK) - val meetsRequirements = - (meetsConstraints && meetsMemoryRequirements && meetsCPURequirements) || + (meetsMemoryRequirements && meetsCPURequirements) || (slaveIdToExecutorInfo.contains(slaveId) && cpus >= scheduler.CPUS_PER_TASK) - - // add some debug messaging val debugstr = if (meetsRequirements) "Accepting" else "Declining" - val id = o.getId.getValue - logDebug(s"$debugstr offer: $id with attributes: $offerAttributes mem: $mem cpu: $cpus") + logDebug(s"$debugstr offer: ${o.getId.getValue} with attributes: " + + s"$offerAttributes mem: $mem cpu: $cpus") meetsRequirements } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 860c8e097b3b9..721861fbbc517 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -336,4 +336,8 @@ private[mesos] trait MesosSchedulerUtils extends Logging { } } + protected def getRejectOfferDurationForUnmetConstraints(sc: SparkContext): Long = { + sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForUnmetConstraints", "120s") + } + } From 51d41e4b1a3a25a3fde3a4345afcfe4766023d23 Mon Sep 17 00:00:00 2001 From: sachin aggarwal Date: Mon, 9 Nov 2015 14:25:42 -0800 Subject: [PATCH 52/88] [SPARK-11552][DOCS][Replaced example code in ml-decision-tree.md using include_example] I have tested it on my local, it is working fine, please review Author: sachin aggarwal Closes #9539 from agsachin/SPARK-11552-real. --- docs/ml-decision-tree.md | 338 +----------------- ...JavaDecisionTreeClassificationExample.java | 103 ++++++ .../ml/JavaDecisionTreeRegressionExample.java | 90 +++++ .../decision_tree_classification_example.py | 77 ++++ .../ml/decision_tree_regression_example.py | 74 ++++ .../DecisionTreeClassificationExample.scala | 94 +++++ .../ml/DecisionTreeRegressionExample.scala | 81 +++++ 7 files changed, 527 insertions(+), 330 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java create mode 100644 examples/src/main/python/ml/decision_tree_classification_example.py create mode 100644 examples/src/main/python/ml/decision_tree_regression_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala diff --git a/docs/ml-decision-tree.md b/docs/ml-decision-tree.md index 542819e93e6dc..2bfac6f6c8378 100644 --- a/docs/ml-decision-tree.md +++ b/docs/ml-decision-tree.md @@ -118,196 +118,24 @@ We use two feature transformers to prepare the data; these help index categories More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier). -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.DecisionTreeClassifier -import org.apache.spark.ml.classification.DecisionTreeClassificationModel -import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -val labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data) -// Automatically identify categorical features, and index them. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) // features with > 4 distinct values are treated as continuous - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a DecisionTree model. -val dt = new DecisionTreeClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures") - -// Convert indexed labels back to original labels. -val labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels) - -// Chain indexers and tree in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) - -// Train model. This also runs the indexers. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision") -val accuracy = evaluator.evaluate(predictions) -println("Test Error = " + (1.0 - accuracy)) - -val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] -println("Learned classification tree model:\n" + treeModel.toDebugString) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala %} +
More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html). -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.DecisionTreeClassifier; -import org.apache.spark.ml.classification.DecisionTreeClassificationModel; -import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; -import org.apache.spark.ml.feature.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); - -// Index labels, adding metadata to the label column. -// Fit on whole dataset to include all labels in index. -StringIndexerModel labelIndexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("indexedLabel") - .fit(data); -// Automatically identify categorical features, and index them. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) // features with > 4 distinct values are treated as continuous - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a DecisionTree model. -DecisionTreeClassifier dt = new DecisionTreeClassifier() - .setLabelCol("indexedLabel") - .setFeaturesCol("indexedFeatures"); - -// Convert indexed labels back to original labels. -IndexToString labelConverter = new IndexToString() - .setInputCol("prediction") - .setOutputCol("predictedLabel") - .setLabels(labelIndexer.labels()); - -// Chain indexers and tree in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {labelIndexer, featureIndexer, dt, labelConverter}); - -// Train model. This also runs the indexers. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("predictedLabel", "label", "features").show(5); - -// Select (prediction, true label) and compute test error -MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() - .setLabelCol("indexedLabel") - .setPredictionCol("prediction") - .setMetricName("precision"); -double accuracy = evaluator.evaluate(predictions); -System.out.println("Test Error = " + (1.0 - accuracy)); - -DecisionTreeClassificationModel treeModel = - (DecisionTreeClassificationModel)(model.stages()[2]); -System.out.println("Learned classification tree model:\n" + treeModel.toDebugString()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java %} +
More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier). -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.classification import DecisionTreeClassifier -from pyspark.ml.feature import StringIndexer, VectorIndexer -from pyspark.ml.evaluation import MulticlassClassificationEvaluator -from pyspark.mllib.util import MLUtils - -# Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -# Index labels, adding metadata to the label column. -# Fit on whole dataset to include all labels in index. -labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) -# Automatically identify categorical features, and index them. -# We specify maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") - -# Chain indexers and tree in a Pipeline -pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt]) - -# Train model. This also runs the indexers. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "indexedLabel", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = MulticlassClassificationEvaluator( - labelCol="indexedLabel", predictionCol="prediction", metricName="precision") -accuracy = evaluator.evaluate(predictions) -print "Test Error = %g" % (1.0 - accuracy) +{% include_example python/ml/decision_tree_classification_example.py %} -treeModel = model.stages[2] -print treeModel # summary only -{% endhighlight %}
@@ -323,171 +151,21 @@ We use a feature transformer to index categorical features, adding metadata to t More details on parameters can be found in the [Scala API documentation](api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor). -{% highlight scala %} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.regression.DecisionTreeRegressor -import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.feature.VectorIndexer -import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file, converting it to a DataFrame. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -// Automatically identify categorical features, and index them. -// Here, we treat features with > 4 distinct values as continuous. -val featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data) - -// Split the data into training and test sets (30% held out for testing) -val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) - -// Train a DecisionTree model. -val dt = new DecisionTreeRegressor() - .setLabelCol("label") - .setFeaturesCol("indexedFeatures") - -// Chain indexer and tree in a Pipeline -val pipeline = new Pipeline() - .setStages(Array(featureIndexer, dt)) - -// Train model. This also runs the indexer. -val model = pipeline.fit(trainingData) - -// Make predictions. -val predictions = model.transform(testData) - -// Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -// Select (prediction, true label) and compute test error -val evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse") -val rmse = evaluator.evaluate(predictions) -println("Root Mean Squared Error (RMSE) on test data = " + rmse) - -val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] -println("Learned regression tree model:\n" + treeModel.toDebugString) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala %}
More details on parameters can be found in the [Java API documentation](api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html). -{% highlight java %} -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineModel; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.feature.VectorIndexer; -import org.apache.spark.ml.feature.VectorIndexerModel; -import org.apache.spark.ml.regression.DecisionTreeRegressionModel; -import org.apache.spark.ml.regression.DecisionTreeRegressor; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.rdd.RDD; -import org.apache.spark.sql.DataFrame; - -// Load and parse the data file, converting it to a DataFrame. -RDD rdd = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt"); -DataFrame data = jsql.createDataFrame(rdd, LabeledPoint.class); - -// Automatically identify categorical features, and index them. -// Set maxCategories so features with > 4 distinct values are treated as continuous. -VectorIndexerModel featureIndexer = new VectorIndexer() - .setInputCol("features") - .setOutputCol("indexedFeatures") - .setMaxCategories(4) - .fit(data); - -// Split the data into training and test sets (30% held out for testing) -DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3}); -DataFrame trainingData = splits[0]; -DataFrame testData = splits[1]; - -// Train a DecisionTree model. -DecisionTreeRegressor dt = new DecisionTreeRegressor() - .setFeaturesCol("indexedFeatures"); - -// Chain indexer and tree in a Pipeline -Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {featureIndexer, dt}); - -// Train model. This also runs the indexer. -PipelineModel model = pipeline.fit(trainingData); - -// Make predictions. -DataFrame predictions = model.transform(testData); - -// Select example rows to display. -predictions.select("label", "features").show(5); - -// Select (prediction, true label) and compute test error -RegressionEvaluator evaluator = new RegressionEvaluator() - .setLabelCol("label") - .setPredictionCol("prediction") - .setMetricName("rmse"); -double rmse = evaluator.evaluate(predictions); -System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); - -DecisionTreeRegressionModel treeModel = - (DecisionTreeRegressionModel)(model.stages()[1]); -System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java %}
More details on parameters can be found in the [Python API documentation](api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor). -{% highlight python %} -from pyspark.ml import Pipeline -from pyspark.ml.regression import DecisionTreeRegressor -from pyspark.ml.feature import VectorIndexer -from pyspark.ml.evaluation import RegressionEvaluator -from pyspark.mllib.util import MLUtils - -# Load and parse the data file, converting it to a DataFrame. -data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() - -# Automatically identify categorical features, and index them. -# We specify maxCategories so features with > 4 distinct values are treated as continuous. -featureIndexer =\ - VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) - -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -dt = DecisionTreeRegressor(featuresCol="indexedFeatures") - -# Chain indexer and tree in a Pipeline -pipeline = Pipeline(stages=[featureIndexer, dt]) - -# Train model. This also runs the indexer. -model = pipeline.fit(trainingData) - -# Make predictions. -predictions = model.transform(testData) - -# Select example rows to display. -predictions.select("prediction", "label", "features").show(5) - -# Select (prediction, true label) and compute test error -evaluator = RegressionEvaluator( - labelCol="label", predictionCol="prediction", metricName="rmse") -rmse = evaluator.evaluate(predictions) -print "Root Mean Squared Error (RMSE) on test data = %g" % rmse - -treeModel = model.stages[1] -print treeModel # summary only -{% endhighlight %} +{% include_example python/ml/decision_tree_regression_example.py %}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java new file mode 100644 index 0000000000000..51c1730a8a085 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// scalastyle:off println +package org.apache.spark.examples.ml; +// $example on$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.DecisionTreeClassifier; +import org.apache.spark.ml.classification.DecisionTreeClassificationModel; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.ml.feature.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaDecisionTreeClassificationExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + RDD rdd = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"); + DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class); + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + StringIndexerModel labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data); + + // Automatically identify categorical features, and index them. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a DecisionTree model. + DecisionTreeClassifier dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures"); + + // Convert indexed labels back to original labels. + IndexToString labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels()); + + // Chain indexers and tree in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter}); + + // Train model. This also runs the indexers. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5); + + // Select (prediction, true label) and compute test error + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision"); + double accuracy = evaluator.evaluate(predictions); + System.out.println("Test Error = " + (1.0 - accuracy)); + + DecisionTreeClassificationModel treeModel = + (DecisionTreeClassificationModel) (model.stages()[2]); + System.out.println("Learned classification tree model:\n" + treeModel.toDebugString()); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java new file mode 100644 index 0000000000000..a4098a4233ec2 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// scalastyle:off println +package org.apache.spark.examples.ml; +// $example on$ +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.evaluation.RegressionEvaluator; +import org.apache.spark.ml.feature.VectorIndexer; +import org.apache.spark.ml.feature.VectorIndexerModel; +import org.apache.spark.ml.regression.DecisionTreeRegressionModel; +import org.apache.spark.ml.regression.DecisionTreeRegressor; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.SQLContext; +// $example off$ + +public class JavaDecisionTreeRegressionExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext sqlContext = new SQLContext(jsc); + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + RDD rdd = MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"); + DataFrame data = sqlContext.createDataFrame(rdd, LabeledPoint.class); + + // Automatically identify categorical features, and index them. + // Set maxCategories so features with > 4 distinct values are treated as continuous. + VectorIndexerModel featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data); + + // Split the data into training and test sets (30% held out for testing) + DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3}); + DataFrame trainingData = splits[0]; + DataFrame testData = splits[1]; + + // Train a DecisionTree model. + DecisionTreeRegressor dt = new DecisionTreeRegressor() + .setFeaturesCol("indexedFeatures"); + + // Chain indexer and tree in a Pipeline + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[]{featureIndexer, dt}); + + // Train model. This also runs the indexer. + PipelineModel model = pipeline.fit(trainingData); + + // Make predictions. + DataFrame predictions = model.transform(testData); + + // Select example rows to display. + predictions.select("label", "features").show(5); + + // Select (prediction, true label) and compute test error + RegressionEvaluator evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse"); + double rmse = evaluator.evaluate(predictions); + System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse); + + DecisionTreeRegressionModel treeModel = + (DecisionTreeRegressionModel) (model.stages()[1]); + System.out.println("Learned regression tree model:\n" + treeModel.toDebugString()); + // $example off$ + } +} diff --git a/examples/src/main/python/ml/decision_tree_classification_example.py b/examples/src/main/python/ml/decision_tree_classification_example.py new file mode 100644 index 0000000000000..0af92050e3e3b --- /dev/null +++ b/examples/src/main/python/ml/decision_tree_classification_example.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Classification Example. +""" +from __future__ import print_function + +import sys + +# $example on$ +from pyspark import SparkContext, SQLContext +from pyspark.ml import Pipeline +from pyspark.ml.classification import DecisionTreeClassifier +from pyspark.ml.feature import StringIndexer, VectorIndexer +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="decision_tree_classification_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Index labels, adding metadata to the label column. + # Fit on whole dataset to include all labels in index. + labelIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel").fit(data) + # Automatically identify categorical features, and index them. + # We specify maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + dt = DecisionTreeClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures") + + # Chain indexers and tree in a Pipeline + pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt]) + + # Train model. This also runs the indexers. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "indexedLabel", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = MulticlassClassificationEvaluator( + labelCol="indexedLabel", predictionCol="prediction", metricName="precision") + accuracy = evaluator.evaluate(predictions) + print("Test Error = %g " % (1.0 - accuracy)) + + treeModel = model.stages[2] + # summary only + print(treeModel) + # $example off$ diff --git a/examples/src/main/python/ml/decision_tree_regression_example.py b/examples/src/main/python/ml/decision_tree_regression_example.py new file mode 100644 index 0000000000000..3857aed538da2 --- /dev/null +++ b/examples/src/main/python/ml/decision_tree_regression_example.py @@ -0,0 +1,74 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Regression Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext, SQLContext +# $example on$ +from pyspark.ml import Pipeline +from pyspark.ml.regression import DecisionTreeRegressor +from pyspark.ml.feature import VectorIndexer +from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="decision_tree_classification_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load and parse the data file, converting it to a DataFrame. + data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Automatically identify categorical features, and index them. + # We specify maxCategories so features with > 4 distinct values are treated as continuous. + featureIndexer =\ + VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(data) + + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + dt = DecisionTreeRegressor(featuresCol="indexedFeatures") + + # Chain indexer and tree in a Pipeline + pipeline = Pipeline(stages=[featureIndexer, dt]) + + # Train model. This also runs the indexer. + model = pipeline.fit(trainingData) + + # Make predictions. + predictions = model.transform(testData) + + # Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + # Select (prediction, true label) and compute test error + evaluator = RegressionEvaluator( + labelCol="label", predictionCol="prediction", metricName="rmse") + rmse = evaluator.evaluate(predictions) + print("Root Mean Squared Error (RMSE) on test data = %g" % rmse) + + treeModel = model.stages[1] + # summary only + print(treeModel) + # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala new file mode 100644 index 0000000000000..a24a344f1bcf4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.DecisionTreeClassifier +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +object DecisionTreeClassificationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeClassificationExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + // Index labels, adding metadata to the label column. + // Fit on whole dataset to include all labels in index. + val labelIndexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("indexedLabel") + .fit(data) + // Automatically identify categorical features, and index them. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) // features with > 4 distinct values are treated as continuous + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a DecisionTree model. + val dt = new DecisionTreeClassifier() + .setLabelCol("indexedLabel") + .setFeaturesCol("indexedFeatures") + + // Convert indexed labels back to original labels. + val labelConverter = new IndexToString() + .setInputCol("prediction") + .setOutputCol("predictedLabel") + .setLabels(labelIndexer.labels) + + // Chain indexers and tree in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter)) + + // Train model. This also runs the indexers. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("predictedLabel", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("indexedLabel") + .setPredictionCol("prediction") + .setMetricName("precision") + val accuracy = evaluator.evaluate(predictions) + println("Test Error = " + (1.0 - accuracy)) + + val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] + println("Learned classification tree model:\n" + treeModel.toDebugString) + // $example off$ + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala new file mode 100644 index 0000000000000..64cd986129007 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml +import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.regression.DecisionTreeRegressor +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.mllib.util.MLUtils +// $example off$ +object DecisionTreeRegressionExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeRegressionExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // $example on$ + // Load and parse the data file, converting it to a DataFrame. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + // Automatically identify categorical features, and index them. + // Here, we treat features with > 4 distinct values as continuous. + val featureIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(4) + .fit(data) + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a DecisionTree model. + val dt = new DecisionTreeRegressor() + .setLabelCol("label") + .setFeaturesCol("indexedFeatures") + + // Chain indexer and tree in a Pipeline + val pipeline = new Pipeline() + .setStages(Array(featureIndexer, dt)) + + // Train model. This also runs the indexer. + val model = pipeline.fit(trainingData) + + // Make predictions. + val predictions = model.transform(testData) + + // Select example rows to display. + predictions.select("prediction", "label", "features").show(5) + + // Select (prediction, true label) and compute test error + val evaluator = new RegressionEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("rmse") + val rmse = evaluator.evaluate(predictions) + println("Root Mean Squared Error (RMSE) on test data = " + rmse) + + val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] + println("Learned regression tree model:\n" + treeModel.toDebugString) + // $example off$ + } +} From b7720fa45525cff6e812fa448d0841cb41f6c8a5 Mon Sep 17 00:00:00 2001 From: Rishabh Bhardwaj Date: Mon, 9 Nov 2015 14:27:36 -0800 Subject: [PATCH 53/88] [SPARK-11548][DOCS] Replaced example code in mllib-collaborative-filtering.md using include_example Kindly review the changes. Author: Rishabh Bhardwaj Closes #9519 from rishabhbhardwaj/SPARK-11337. --- docs/mllib-collaborative-filtering.md | 138 +----------------- .../mllib/JavaRecommendationExample.java | 97 ++++++++++++ .../python/mllib/recommendation_example.py | 54 +++++++ .../mllib/RecommendationExample.scala | 67 +++++++++ 4 files changed, 221 insertions(+), 135 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java create mode 100644 examples/src/main/python/mllib/recommendation_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 1ad52123c74aa..7cd1b894e7cb5 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -66,43 +66,7 @@ recommendation model by measuring the Mean Squared Error of rating prediction. Refer to the [`ALS` Scala docs](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.recommendation.ALS -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel -import org.apache.spark.mllib.recommendation.Rating - -// Load and parse the data -val data = sc.textFile("data/mllib/als/test.data") -val ratings = data.map(_.split(',') match { case Array(user, item, rate) => - Rating(user.toInt, item.toInt, rate.toDouble) - }) - -// Build the recommendation model using ALS -val rank = 10 -val numIterations = 10 -val model = ALS.train(ratings, rank, numIterations, 0.01) - -// Evaluate the model on rating data -val usersProducts = ratings.map { case Rating(user, product, rate) => - (user, product) -} -val predictions = - model.predict(usersProducts).map { case Rating(user, product, rate) => - ((user, product), rate) - } -val ratesAndPreds = ratings.map { case Rating(user, product, rate) => - ((user, product), rate) -}.join(predictions) -val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => - val err = (r1 - r2) - err * err -}.mean() -println("Mean Squared Error = " + MSE) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = MatrixFactorizationModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/RecommendationExample.scala %} If the rating matrix is derived from another source of information (e.g., it is inferred from other signals), you can use the `trainImplicit` method to get better results. @@ -123,81 +87,7 @@ that is equivalent to the provided example in Scala is given below: Refer to the [`ALS` Java docs](api/java/org/apache/spark/mllib/recommendation/ALS.html) for details on the API. -{% highlight java %} -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.recommendation.ALS; -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; -import org.apache.spark.mllib.recommendation.Rating; -import org.apache.spark.SparkConf; - -public class CollaborativeFiltering { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("Collaborative Filtering Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/als/test.data"; - JavaRDD data = sc.textFile(path); - JavaRDD ratings = data.map( - new Function() { - public Rating call(String s) { - String[] sarray = s.split(","); - return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), - Double.parseDouble(sarray[2])); - } - } - ); - - // Build the recommendation model using ALS - int rank = 10; - int numIterations = 10; - MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); - - // Evaluate the model on rating data - JavaRDD> userProducts = ratings.map( - new Function>() { - public Tuple2 call(Rating r) { - return new Tuple2(r.user(), r.product()); - } - } - ); - JavaPairRDD, Double> predictions = JavaPairRDD.fromJavaRDD( - model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( - new Function, Double>>() { - public Tuple2, Double> call(Rating r){ - return new Tuple2, Double>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )); - JavaRDD> ratesAndPreds = - JavaPairRDD.fromJavaRDD(ratings.map( - new Function, Double>>() { - public Tuple2, Double> call(Rating r){ - return new Tuple2, Double>( - new Tuple2(r.user(), r.product()), r.rating()); - } - } - )).join(predictions).values(); - double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map( - new Function, Object>() { - public Object call(Tuple2 pair) { - Double err = pair._1() - pair._2(); - return err * err; - } - } - ).rdd()).mean(); - System.out.println("Mean Squared Error = " + MSE); - - // Save and load model - model.save(sc.sc(), "myModelPath"); - MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(sc.sc(), "myModelPath"); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaRecommendationExample.java %}
@@ -207,29 +97,7 @@ recommendation by measuring the Mean Squared Error of rating prediction. Refer to the [`ALS` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.recommendation.ALS) for more details on the API. -{% highlight python %} -from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating - -# Load and parse the data -data = sc.textFile("data/mllib/als/test.data") -ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) - -# Build the recommendation model using Alternating Least Squares -rank = 10 -numIterations = 10 -model = ALS.train(ratings, rank, numIterations) - -# Evaluate the model on training data -testdata = ratings.map(lambda p: (p[0], p[1])) -predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) -ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) -MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean() -print("Mean Squared Error = " + str(MSE)) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = MatrixFactorizationModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/recommendation_example.py %} If the rating matrix is derived from other source of information (i.e., it is inferred from other signals), you can use the trainImplicit method to get better results. diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java new file mode 100644 index 0000000000000..1065fde953b96 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRecommendationExample.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.recommendation.ALS; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.Rating; +import org.apache.spark.SparkConf; +// $example off$ + +public class JavaRecommendationExample { + public static void main(String args[]) { + // $example on$ + SparkConf conf = new SparkConf().setAppName("Java Collaborative Filtering Example"); + JavaSparkContext jsc = new JavaSparkContext(conf); + + // Load and parse the data + String path = "data/mllib/als/test.data"; + JavaRDD data = jsc.textFile(path); + JavaRDD ratings = data.map( + new Function() { + public Rating call(String s) { + String[] sarray = s.split(","); + return new Rating(Integer.parseInt(sarray[0]), Integer.parseInt(sarray[1]), + Double.parseDouble(sarray[2])); + } + } + ); + + // Build the recommendation model using ALS + int rank = 10; + int numIterations = 10; + MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), rank, numIterations, 0.01); + + // Evaluate the model on rating data + JavaRDD> userProducts = ratings.map( + new Function>() { + public Tuple2 call(Rating r) { + return new Tuple2(r.user(), r.product()); + } + } + ); + JavaPairRDD, Double> predictions = JavaPairRDD.fromJavaRDD( + model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( + new Function, Double>>() { + public Tuple2, Double> call(Rating r){ + return new Tuple2, Double>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )); + JavaRDD> ratesAndPreds = + JavaPairRDD.fromJavaRDD(ratings.map( + new Function, Double>>() { + public Tuple2, Double> call(Rating r){ + return new Tuple2, Double>( + new Tuple2(r.user(), r.product()), r.rating()); + } + } + )).join(predictions).values(); + double MSE = JavaDoubleRDD.fromRDD(ratesAndPreds.map( + new Function, Object>() { + public Object call(Tuple2 pair) { + Double err = pair._1() - pair._2(); + return err * err; + } + } + ).rdd()).mean(); + System.out.println("Mean Squared Error = " + MSE); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myCollaborativeFilter"); + MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(jsc.sc(), + "target/tmp/myCollaborativeFilter"); + // $example off$ + } +} diff --git a/examples/src/main/python/mllib/recommendation_example.py b/examples/src/main/python/mllib/recommendation_example.py new file mode 100644 index 0000000000000..615db0749b182 --- /dev/null +++ b/examples/src/main/python/mllib/recommendation_example.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Collaborative Filtering Classification Example. +""" +from __future__ import print_function + +import sys + +from pyspark import SparkContext + +# $example on$ +from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonCollaborativeFilteringExample") + # $example on$ + # Load and parse the data + data = sc.textFile("data/mllib/als/test.data") + ratings = data.map(lambda l: l.split(','))\ + .map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) + + # Build the recommendation model using Alternating Least Squares + rank = 10 + numIterations = 10 + model = ALS.train(ratings, rank, numIterations) + + # Evaluate the model on training data + testdata = ratings.map(lambda p: (p[0], p[1])) + predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) + ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) + MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean() + print("Mean Squared Error = " + str(MSE)) + + # Save and load model + model.save(sc, "target/tmp/myCollaborativeFilter") + sameModel = MatrixFactorizationModel.load(sc, "target/tmp/myCollaborativeFilter") + # $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala new file mode 100644 index 0000000000000..64e4602465444 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkContext, SparkConf} +// $example on$ +import org.apache.spark.mllib.recommendation.ALS +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel +import org.apache.spark.mllib.recommendation.Rating +// $example off$ + +object RecommendationExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("CollaborativeFilteringExample") + val sc = new SparkContext(conf) + // $example on$ + // Load and parse the data + val data = sc.textFile("data/mllib/als/test.data") + val ratings = data.map(_.split(',') match { case Array(user, item, rate) => + Rating(user.toInt, item.toInt, rate.toDouble) + }) + + // Build the recommendation model using ALS + val rank = 10 + val numIterations = 10 + val model = ALS.train(ratings, rank, numIterations, 0.01) + + // Evaluate the model on rating data + val usersProducts = ratings.map { case Rating(user, product, rate) => + (user, product) + } + val predictions = + model.predict(usersProducts).map { case Rating(user, product, rate) => + ((user, product), rate) + } + val ratesAndPreds = ratings.map { case Rating(user, product, rate) => + ((user, product), rate) + }.join(predictions) + val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) => + val err = (r1 - r2) + err * err + }.mean() + println("Mean Squared Error = " + MSE) + + // Save and load model + model.save(sc, "target/tmp/myCollaborativeFilter") + val sameModel = MatrixFactorizationModel.load(sc, "target/tmp/myCollaborativeFilter") + // $example off$ + } +} +// scalastyle:on println From f138cb873335654476d1cd1070900b552dd8b21a Mon Sep 17 00:00:00 2001 From: Nick Buroojy Date: Mon, 9 Nov 2015 14:30:37 -0800 Subject: [PATCH 54/88] [SPARK-9301][SQL] Add collect_set and collect_list aggregate functions For now they are thin wrappers around the corresponding Hive UDAFs. One limitation with these in Hive 0.13.0 is they only support aggregating primitive types. I chose snake_case here instead of camelCase because it seems to be used in the majority of the multi-word fns. Do we also want to add these to `functions.py`? This approach was recommended here: https://github.com/apache/spark/pull/8592#issuecomment-154247089 marmbrus rxin Author: Nick Buroojy Closes #9526 from nburoojy/nick/udaf-alias. (cherry picked from commit a6ee4f989d020420dd08b97abb24802200ff23b2) Signed-off-by: Michael Armbrust --- python/pyspark/sql/functions.py | 25 +++++++++++-------- python/pyspark/sql/tests.py | 17 +++++++++++++ .../org/apache/spark/sql/functions.scala | 20 +++++++++++++++ .../hive/HiveDataFrameAnalyticsSuite.scala | 15 +++++++++-- 4 files changed, 64 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 2f7c2f4aacd47..962f676d406d8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -124,17 +124,20 @@ def _(): _functions_1_6 = { # unary math functions - "stddev": "Aggregate function: returns the unbiased sample standard deviation of" + - " the expression in a group.", - "stddev_samp": "Aggregate function: returns the unbiased sample standard deviation of" + - " the expression in a group.", - "stddev_pop": "Aggregate function: returns population standard deviation of" + - " the expression in a group.", - "variance": "Aggregate function: returns the population variance of the values in a group.", - "var_samp": "Aggregate function: returns the unbiased variance of the values in a group.", - "var_pop": "Aggregate function: returns the population variance of the values in a group.", - "skewness": "Aggregate function: returns the skewness of the values in a group.", - "kurtosis": "Aggregate function: returns the kurtosis of the values in a group." + 'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' + + ' the expression in a group.', + 'stddev_samp': 'Aggregate function: returns the unbiased sample standard deviation of' + + ' the expression in a group.', + 'stddev_pop': 'Aggregate function: returns population standard deviation of' + + ' the expression in a group.', + 'variance': 'Aggregate function: returns the population variance of the values in a group.', + 'var_samp': 'Aggregate function: returns the unbiased variance of the values in a group.', + 'var_pop': 'Aggregate function: returns the population variance of the values in a group.', + 'skewness': 'Aggregate function: returns the skewness of the values in a group.', + 'kurtosis': 'Aggregate function: returns the kurtosis of the values in a group.', + 'collect_list': 'Aggregate function: returns a list of objects with duplicates.', + 'collect_set': 'Aggregate function: returns a set of objects with duplicate elements' + + ' eliminated.' } # math functions that take two arguments as input diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4c03a0d4ffe93..e224574bcb301 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1230,6 +1230,23 @@ def test_window_functions_without_partitionBy(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_collect_functions(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql import functions + + self.assertEqual( + sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r), + [1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r), + [1, 1, 1, 2]) + self.assertEqual( + sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r), + ["1", "2"]) + self.assertEqual( + sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r), + ["1", "2", "2", "2"]) + if __name__ == "__main__": if xmlrunner: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 04627589886a8..3f0b24b68b816 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -174,6 +174,26 @@ object functions { */ def avg(columnName: String): Column = avg(Column(columnName)) + /** + * Aggregate function: returns a list of objects with duplicates. + * + * For now this is an alias for the collect_list Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_list(e: Column): Column = callUDF("collect_list", e) + + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * + * For now this is an alias for the collect_set Hive UDAF. + * + * @group agg_funcs + * @since 1.6.0 + */ + def collect_set(e: Column): Column = callUDF("collect_set", e) + /** * Aggregate function: returns the Pearson Correlation Coefficient for two columns. * diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 2e5cae415e54b..9864acf765265 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.scalatest.BeforeAndAfterAll @@ -32,7 +32,7 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with private var testData: DataFrame = _ override def beforeAll() { - testData = Seq((1, 2), (2, 4)).toDF("a", "b") + testData = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") hiveContext.registerDataFrameAsTable(testData, "mytable") } @@ -52,6 +52,17 @@ class HiveDataFrameAnalyticsSuite extends QueryTest with TestHiveSingleton with ) } + test("collect functions") { + checkAnswer( + testData.select(collect_list($"a"), collect_list($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 2, 4))) + ) + checkAnswer( + testData.select(collect_set($"a"), collect_set($"b")), + Seq(Row(Seq(1, 2, 3), Seq(2, 4))) + ) + } + test("cube") { checkAnswer( testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), From 150f6a89b79f0e5bc31aa83731429dc7ac5ea76b Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Mon, 9 Nov 2015 14:32:52 -0800 Subject: [PATCH 55/88] [SPARK-11595] [SQL] Fixes ADD JAR when the input path contains URL scheme Author: Cheng Lian Closes #9569 from liancheng/spark-11595.fix-add-jar. --- .../hive/thriftserver/HiveThriftServer2Suites.scala | 1 + .../apache/spark/sql/hive/client/ClientWrapper.scala | 11 +++++++++-- .../spark/sql/hive/client/IsolatedClientLoader.scala | 9 +++------ .../spark/sql/hive/execution/HiveQuerySuite.scala | 8 +++++--- 4 files changed, 18 insertions(+), 11 deletions(-) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index ff8ca0150649d..5903b9e71cdd2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -41,6 +41,7 @@ import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkFunSuite} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 3dce86c480747..f1c2489b38271 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.client import java.io.{File, PrintStream} import java.util.{Map => JMap} -import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.language.reflectiveCalls @@ -548,7 +547,15 @@ private[hive] class ClientWrapper( } def addJar(path: String): Unit = { - clientLoader.addJar(path) + val uri = new Path(path).toUri + val jarURL = if (uri.getScheme == null) { + // `path` is a local file path without a URL scheme + new File(path).toURI.toURL + } else { + // `path` is a URL with a scheme + uri.toURL + } + clientLoader.addJar(jarURL) runSqlHive(s"ADD JAR $path") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index f99c3ed2ae987..e041e0d8e5ae8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -22,7 +22,6 @@ import java.lang.reflect.InvocationTargetException import java.net.{URL, URLClassLoader} import java.util -import scala.collection.mutable import scala.language.reflectiveCalls import scala.util.Try @@ -30,10 +29,9 @@ import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.spark.Logging import org.apache.spark.deploy.SparkSubmitUtils -import org.apache.spark.util.{MutableURLClassLoader, Utils} - import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.util.{MutableURLClassLoader, Utils} /** Factory for `IsolatedClientLoader` with specific versions of hive. */ private[hive] object IsolatedClientLoader { @@ -190,9 +188,8 @@ private[hive] class IsolatedClientLoader( new NonClosableMutableURLClassLoader(isolatedClassLoader) } - private[hive] def addJar(path: String): Unit = synchronized { - val jarURL = new java.io.File(path).toURI.toURL - classLoader.addURL(jarURL) + private[hive] def addJar(path: URL): Unit = synchronized { + classLoader.addURL(path) } /** The isolated client interface to Hive. */ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index fc72e3c7dc6aa..78378c8b69c7a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -927,7 +927,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("SPARK-2263: Insert Map values") { sql("CREATE TABLE m(value MAP)") sql("INSERT OVERWRITE TABLE m SELECT MAP(key, value) FROM src LIMIT 10") - sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).map { + sql("SELECT * FROM m").collect().zip(sql("SELECT * FROM src LIMIT 10").collect()).foreach { case (Row(map: Map[_, _]), Row(key: Int, value: String)) => assert(map.size === 1) assert(map.head === (key, value)) @@ -961,10 +961,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { test("CREATE TEMPORARY FUNCTION") { val funcJar = TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath - sql(s"ADD JAR $funcJar") + val jarURL = s"file://$funcJar" + sql(s"ADD JAR $jarURL") sql( """CREATE TEMPORARY FUNCTION udtf_count2 AS - | 'org.apache.spark.sql.hive.execution.GenericUDTFCount2'""".stripMargin) + |'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) assert(sql("DESCRIBE FUNCTION udtf_count2").count > 1) sql("DROP TEMPORARY FUNCTION udtf_count2") } From a3a7c9103e136035d65a5564f9eb0fa04727c4f3 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 9 Nov 2015 14:39:18 -0800 Subject: [PATCH 56/88] [SPARK-11359][STREAMING][KINESIS] Checkpoint to DynamoDB even when new data doesn't come in Currently, the checkpoints to DynamoDB occur only when new data comes in, as we update the clock for the checkpointState. This PR makes the checkpoint a scheduled execution based on the `checkpointInterval`. Author: Burak Yavuz Closes #9421 from brkyvz/kinesis-checkpoint. --- .../kinesis/KinesisCheckpointState.scala | 54 ------- .../kinesis/KinesisCheckpointer.scala | 133 +++++++++++++++ .../streaming/kinesis/KinesisReceiver.scala | 38 ++++- .../kinesis/KinesisRecordProcessor.scala | 59 ++----- .../kinesis/KinesisCheckpointerSuite.scala | 152 ++++++++++++++++++ .../kinesis/KinesisReceiverSuite.scala | 96 +++-------- 6 files changed, 349 insertions(+), 183 deletions(-) delete mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala deleted file mode 100644 index 83a4537559512..0000000000000 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.streaming.kinesis - -import org.apache.spark.Logging -import org.apache.spark.streaming.Duration -import org.apache.spark.util.{Clock, ManualClock, SystemClock} - -/** - * This is a helper class for managing checkpoint clocks. - * - * @param checkpointInterval - * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes) - */ -private[kinesis] class KinesisCheckpointState( - checkpointInterval: Duration, - currentClock: Clock = new SystemClock()) - extends Logging { - - /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ - val checkpointClock = new ManualClock() - checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds) - - /** - * Check if it's time to checkpoint based on the current time and the derived time - * for the next checkpoint - * - * @return true if it's time to checkpoint - */ - def shouldCheckpoint(): Boolean = { - new SystemClock().getTimeMillis() > checkpointClock.getTimeMillis() - } - - /** - * Advance the checkpoint clock by the checkpoint interval. - */ - def advanceCheckpoint(): Unit = { - checkpointClock.advance(checkpointInterval.milliseconds) - } -} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala new file mode 100644 index 0000000000000..1ca6d4302c2bb --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.util.concurrent._ + +import scala.util.control.NonFatal + +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason + +import org.apache.spark.Logging +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.util.RecurringTimer +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} + +/** + * This is a helper class for managing Kinesis checkpointing. + * + * @param receiver The receiver that keeps track of which sequence numbers we can checkpoint + * @param checkpointInterval How frequently we will checkpoint to DynamoDB + * @param workerId Worker Id of KCL worker for logging purposes + * @param clock In order to use ManualClocks for the purpose of testing + */ +private[kinesis] class KinesisCheckpointer( + receiver: KinesisReceiver[_], + checkpointInterval: Duration, + workerId: String, + clock: Clock = new SystemClock) extends Logging { + + // a map from shardId's to checkpointers + private val checkpointers = new ConcurrentHashMap[String, IRecordProcessorCheckpointer]() + + private val lastCheckpointedSeqNums = new ConcurrentHashMap[String, String]() + + private val checkpointerThread: RecurringTimer = startCheckpointerThread() + + /** Update the checkpointer instance to the most recent one for the given shardId. */ + def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + checkpointers.put(shardId, checkpointer) + } + + /** + * Stop tracking the specified shardId. + * + * If a checkpointer is provided, e.g. on IRecordProcessor.shutdown [[ShutdownReason.TERMINATE]], + * we will use that to make the final checkpoint. If `null` is provided, we will not make the + * checkpoint, e.g. in case of [[ShutdownReason.ZOMBIE]]. + */ + def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + synchronized { + checkpointers.remove(shardId) + checkpoint(shardId, checkpointer) + } + } + + /** Perform the checkpoint. */ + private def checkpoint(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + try { + if (checkpointer != null) { + receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum => + val lastSeqNum = lastCheckpointedSeqNums.get(shardId) + // Kinesis sequence numbers are monotonically increasing strings, therefore we can do + // safely do the string comparison + if (lastSeqNum == null || latestSeqNum > lastSeqNum) { + /* Perform the checkpoint */ + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100) + logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint at sequence number" + + s" $latestSeqNum for shardId $shardId") + lastCheckpointedSeqNums.put(shardId, latestSeqNum) + } + } + } else { + logDebug(s"Checkpointing skipped for shardId $shardId. Checkpointer not set.") + } + } catch { + case NonFatal(e) => + logWarning(s"Failed to checkpoint shardId $shardId to DynamoDB.", e) + } + } + + /** Checkpoint the latest saved sequence numbers for all active shardId's. */ + private def checkpointAll(): Unit = synchronized { + // if this method throws an exception, then the scheduled task will not run again + try { + val shardIds = checkpointers.keys() + while (shardIds.hasMoreElements) { + val shardId = shardIds.nextElement() + checkpoint(shardId, checkpointers.get(shardId)) + } + } catch { + case NonFatal(e) => + logWarning("Failed to checkpoint to DynamoDB.", e) + } + } + + /** + * Start the checkpointer thread with the given checkpoint duration. + */ + private def startCheckpointerThread(): RecurringTimer = { + val period = checkpointInterval.milliseconds + val threadName = s"Kinesis Checkpointer - Worker $workerId" + val timer = new RecurringTimer(clock, period, _ => checkpointAll(), threadName) + timer.start() + logDebug(s"Started checkpointer thread: $threadName") + timer + } + + /** + * Shutdown the checkpointer. Should be called on the onStop of the Receiver. + */ + def shutdown(): Unit = { + // the recurring timer checkpoints for us one last time. + checkpointerThread.stop(interruptTimer = false) + checkpointers.clear() + lastCheckpointedSeqNums.clear() + logInfo("Successfully shutdown Kinesis Checkpointer.") + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 134d627cdaffa..50993f157cd95 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain} -import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorFactory} +import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessorCheckpointer, IRecordProcessor, IRecordProcessorFactory} import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker} import com.amazonaws.services.kinesis.model.Record @@ -31,8 +31,7 @@ import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.Duration import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkEnv} - +import org.apache.spark.Logging private[kinesis] case class SerializableAWSCredentials(accessKeyId: String, secretKey: String) @@ -127,6 +126,11 @@ private[kinesis] class KinesisReceiver[T]( private val blockIdToSeqNumRanges = new mutable.HashMap[StreamBlockId, SequenceNumberRanges] with mutable.SynchronizedMap[StreamBlockId, SequenceNumberRanges] + /** + * The centralized kinesisCheckpointer that checkpoints based on the given checkpointInterval. + */ + @volatile private var kinesisCheckpointer: KinesisCheckpointer = null + /** * Latest sequence number ranges that have been stored successfully. * This is used for checkpointing through KCL */ @@ -141,6 +145,7 @@ private[kinesis] class KinesisReceiver[T]( workerId = Utils.localHostName() + ":" + UUID.randomUUID() + kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) // KCL config instance val awsCredProvider = resolveAWSCredentialsProvider() val kinesisClientLibConfiguration = @@ -157,8 +162,8 @@ private[kinesis] class KinesisReceiver[T]( * We're using our custom KinesisRecordProcessor in this case. */ val recordProcessorFactory = new IRecordProcessorFactory { - override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, - workerId, new KinesisCheckpointState(checkpointInterval)) + override def createProcessor: IRecordProcessor = + new KinesisRecordProcessor(receiver, workerId) } worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) @@ -198,6 +203,10 @@ private[kinesis] class KinesisReceiver[T]( logInfo(s"Stopped receiver for workerId $workerId") } workerId = null + if (kinesisCheckpointer != null) { + kinesisCheckpointer.shutdown() + kinesisCheckpointer = null + } } /** Add records of the given shard to the current block being generated */ @@ -216,6 +225,25 @@ private[kinesis] class KinesisReceiver[T]( shardIdToLatestStoredSeqNum.get(shardId) } + /** + * Set the checkpointer that will be used to checkpoint sequence numbers to DynamoDB for the + * given shardId. + */ + def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!") + kinesisCheckpointer.setCheckpointer(shardId, checkpointer) + } + + /** + * Remove the checkpointer for the given shardId. The provided checkpointer will be used to + * checkpoint one last time for the given shard. If `checkpointer` is `null`, then we will not + * checkpoint. + */ + def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!") + kinesisCheckpointer.removeCheckpointer(shardId, checkpointer) + } + /** * Remember the range of sequence numbers that was added to the currently active block. * Internally, this is synchronized with `finalizeRangesForCurrentBlock()`. diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 1d5178790ec4c..e381ffa0cbef4 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -27,26 +27,23 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.apache.spark.Logging +import org.apache.spark.streaming.Duration /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. * This implementation operates on the Array[Byte] from the KinesisReceiver. * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each - * shard in the Kinesis stream upon startup. This is normally done in separate threads, - * but the KCLs within the KinesisReceivers will balance themselves out if you create - * multiple Receivers. + * shard in the Kinesis stream upon startup. This is normally done in separate threads, + * but the KCLs within the KinesisReceivers will balance themselves out if you create + * multiple Receivers. * * @param receiver Kinesis receiver * @param workerId for logging purposes - * @param checkpointState represents the checkpoint state including the next checkpoint time. - * It's injected here for mocking purposes. */ -private[kinesis] class KinesisRecordProcessor[T]( - receiver: KinesisReceiver[T], - workerId: String, - checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { +private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], workerId: String) + extends IRecordProcessor with Logging { - // shardId to be populated during initialize() + // shardId populated during initialize() @volatile private var shardId: String = _ @@ -74,34 +71,7 @@ private[kinesis] class KinesisRecordProcessor[T]( try { receiver.addRecords(shardId, batch) logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") - - /* - * - * Checkpoint the sequence number of the last record successfully stored. - * Note that in this current implementation, the checkpointing occurs only when after - * checkpointIntervalMillis from the last checkpoint, AND when there is new record - * to process. This leads to the checkpointing lagging behind what records have been - * stored by the receiver. Ofcourse, this can lead records processed more than once, - * under failures and restarts. - * - * TODO: Instead of checkpointing here, run a separate timer task to perform - * checkpointing so that it checkpoints in a timely manner independent of whether - * new records are available or not. - */ - if (checkpointState.shouldCheckpoint()) { - receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum => - /* Perform the checkpoint */ - KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100) - - /* Update the next checkpoint time */ - checkpointState.advanceCheckpoint() - - logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + - s" records for shardId $shardId") - logDebug(s"Checkpoint: Next checkpoint is at " + - s" ${checkpointState.checkpointClock.getTimeMillis()} for shardId $shardId") - } - } + receiver.setCheckpointer(shardId, checkpointer) } catch { case NonFatal(e) => { /* @@ -142,23 +112,18 @@ private[kinesis] class KinesisRecordProcessor[T]( * It's now OK to read from the new shards that resulted from a resharding event. */ case ShutdownReason.TERMINATE => - val latestSeqNumToCheckpointOption = receiver.getLatestSeqNumToCheckpoint(shardId) - if (latestSeqNumToCheckpointOption.nonEmpty) { - KinesisRecordProcessor.retryRandom( - checkpointer.checkpoint(latestSeqNumToCheckpointOption.get), 4, 100) - } + receiver.removeCheckpointer(shardId, checkpointer) /* - * ZOMBIE Use Case. NoOp. + * ZOMBIE Use Case or Unknown reason. NoOp. * No checkpoint because other workers may have taken over and already started processing * the same records. * This may lead to records being processed more than once. */ - case ShutdownReason.ZOMBIE => - - /* Unknown reason. NoOp */ case _ => + receiver.removeCheckpointer(shardId, null) // return null so that we don't checkpoint } + } } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala new file mode 100644 index 0000000000000..645e64a0bc3a0 --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis + +import java.util.concurrent.{TimeoutException, ExecutorService} + +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ +import scala.language.postfixOps + +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach} +import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.Eventually._ +import org.scalatest.mock.MockitoSugar + +import org.apache.spark.streaming.{Duration, TestSuiteBase} +import org.apache.spark.util.ManualClock + +class KinesisCheckpointerSuite extends TestSuiteBase + with MockitoSugar + with BeforeAndAfterEach + with PrivateMethodTester + with Eventually { + + private val workerId = "dummyWorkerId" + private val shardId = "dummyShardId" + private val seqNum = "123" + private val otherSeqNum = "245" + private val checkpointInterval = Duration(10) + private val someSeqNum = Some(seqNum) + private val someOtherSeqNum = Some(otherSeqNum) + + private var receiverMock: KinesisReceiver[Array[Byte]] = _ + private var checkpointerMock: IRecordProcessorCheckpointer = _ + private var kinesisCheckpointer: KinesisCheckpointer = _ + private var clock: ManualClock = _ + + private val checkpoint = PrivateMethod[Unit]('checkpoint) + + override def beforeEach(): Unit = { + receiverMock = mock[KinesisReceiver[Array[Byte]]] + checkpointerMock = mock[IRecordProcessorCheckpointer] + clock = new ManualClock() + kinesisCheckpointer = new KinesisCheckpointer(receiverMock, checkpointInterval, workerId, clock) + } + + test("checkpoint is not called twice for the same sequence number") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + + test("checkpoint is called after sequence number increases") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock)) + + verify(checkpointerMock, times(1)).checkpoint(seqNum) + verify(checkpointerMock, times(1)).checkpoint(otherSeqNum) + } + + test("should checkpoint if we have exceeded the checkpoint interval") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(5 * checkpointInterval.milliseconds) + + eventually(timeout(1 second)) { + verify(checkpointerMock, times(1)).checkpoint(seqNum) + verify(checkpointerMock, times(1)).checkpoint(otherSeqNum) + } + } + + test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(checkpointInterval.milliseconds / 2) + + verify(checkpointerMock, never()).checkpoint(anyString()) + } + + test("should not checkpoint for the same sequence number") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + + clock.advance(checkpointInterval.milliseconds * 5) + eventually(timeout(1 second)) { + verify(checkpointerMock, atMost(1)).checkpoint(anyString()) + } + } + + test("removing checkpointer checkpoints one last time") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock) + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + + test("if checkpointing is going on, wait until finished before removing and checkpointing") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)) + .thenReturn(someSeqNum).thenReturn(someOtherSeqNum) + when(checkpointerMock.checkpoint(anyString)).thenAnswer(new Answer[Unit] { + override def answer(invocations: InvocationOnMock): Unit = { + clock.waitTillTime(clock.getTimeMillis() + checkpointInterval.milliseconds / 2) + } + }) + + kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock) + clock.advance(checkpointInterval.milliseconds) + eventually(timeout(1 second)) { + verify(checkpointerMock, times(1)).checkpoint(anyString()) + } + // don't block test thread + val f = Future(kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock))( + ExecutionContext.global) + + intercept[TimeoutException] { + Await.ready(f, 50 millis) + } + + clock.advance(checkpointInterval.milliseconds / 2) + eventually(timeout(1 second)) { + verify(checkpointerMock, times(2)).checkpoint(anyString()) + } + } +} diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index 17ab444704f44..e5c70db554a27 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -25,12 +25,13 @@ import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorC import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason import com.amazonaws.services.kinesis.model.Record import org.mockito.Matchers._ +import org.mockito.Matchers.{eq => meq} import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar import org.scalatest.{BeforeAndAfter, Matchers} -import org.apache.spark.streaming.{Milliseconds, TestSuiteBase} -import org.apache.spark.util.{Clock, ManualClock, Utils} +import org.apache.spark.streaming.{Duration, TestSuiteBase} +import org.apache.spark.util.Utils /** * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor @@ -44,6 +45,7 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft val workerId = "dummyWorkerId" val shardId = "dummyShardId" val seqNum = "dummySeqNum" + val checkpointInterval = Duration(10) val someSeqNum = Some(seqNum) val record1 = new Record() @@ -54,24 +56,10 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft var receiverMock: KinesisReceiver[Array[Byte]] = _ var checkpointerMock: IRecordProcessorCheckpointer = _ - var checkpointClockMock: ManualClock = _ - var checkpointStateMock: KinesisCheckpointState = _ - var currentClockMock: Clock = _ override def beforeFunction(): Unit = { receiverMock = mock[KinesisReceiver[Array[Byte]]] checkpointerMock = mock[IRecordProcessorCheckpointer] - checkpointClockMock = mock[ManualClock] - checkpointStateMock = mock[KinesisCheckpointState] - currentClockMock = mock[Clock] - } - - override def afterFunction(): Unit = { - super.afterFunction() - // Since this suite was originally written using EasyMock, add this to preserve the old - // mocking semantics (see SPARK-5735 for more details) - verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock, - checkpointStateMock, currentClockMock) } test("check serializability of SerializableAWSCredentials") { @@ -79,113 +67,67 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft Utils.serialize(new SerializableAWSCredentials("x", "y"))) } - test("process records including store and checkpoint") { + test("process records including store and set checkpointer") { when(receiverMock.isStopped()).thenReturn(false) - when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - when(checkpointStateMock.shouldCheckpoint()).thenReturn(true) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.initialize(shardId) recordProcessor.processRecords(batch, checkpointerMock) verify(receiverMock, times(1)).isStopped() verify(receiverMock, times(1)).addRecords(shardId, batch) - verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId) - verify(checkpointStateMock, times(1)).shouldCheckpoint() - verify(checkpointerMock, times(1)).checkpoint(anyString) - verify(checkpointStateMock, times(1)).advanceCheckpoint() + verify(receiverMock, times(1)).setCheckpointer(shardId, checkpointerMock) } - test("shouldn't store and checkpoint when receiver is stopped") { + test("shouldn't store and update checkpointer when receiver is stopped") { when(receiverMock.isStopped()).thenReturn(true) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.processRecords(batch, checkpointerMock) verify(receiverMock, times(1)).isStopped() verify(receiverMock, never).addRecords(anyString, anyListOf(classOf[Record])) - verify(checkpointerMock, never).checkpoint(anyString) + verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock)) } - test("shouldn't checkpoint when exception occurs during store") { + test("shouldn't update checkpointer when exception occurs during store") { when(receiverMock.isStopped()).thenReturn(false) when( receiverMock.addRecords(anyString, anyListOf(classOf[Record])) ).thenThrow(new RuntimeException()) intercept[RuntimeException] { - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.initialize(shardId) recordProcessor.processRecords(batch, checkpointerMock) } verify(receiverMock, times(1)).isStopped() verify(receiverMock, times(1)).addRecords(shardId, batch) - verify(checkpointerMock, never).checkpoint(anyString) - } - - test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointIntervalMillis = 10 - val checkpointState = - new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("should checkpoint if we have exceeded the checkpoint interval") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) - assert(checkpointState.shouldCheckpoint()) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) - assert(!checkpointState.shouldCheckpoint()) - - verify(currentClockMock, times(1)).getTimeMillis() - } - - test("should add to time when advancing checkpoint") { - when(currentClockMock.getTimeMillis()).thenReturn(0) - - val checkpointIntervalMillis = 10 - val checkpointState = - new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) - assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis) - checkpointState.advanceCheckpoint() - assert(checkpointState.checkpointClock.getTimeMillis() == (2 * checkpointIntervalMillis)) - - verify(currentClockMock, times(1)).getTimeMillis() + verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock)) } test("shutdown should checkpoint if the reason is TERMINATE") { when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.initialize(shardId) recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE) - verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId) - verify(checkpointerMock, times(1)).checkpoint(anyString) + verify(receiverMock, times(1)).removeCheckpointer(meq(shardId), meq(checkpointerMock)) } + test("shutdown should not checkpoint if the reason is something other than TERMINATE") { when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) recordProcessor.initialize(shardId) recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) recordProcessor.shutdown(checkpointerMock, null) - verify(checkpointerMock, never).checkpoint(anyString) + verify(receiverMock, times(2)).removeCheckpointer(meq(shardId), + meq[IRecordProcessorCheckpointer](null)) } test("retry success on first attempt") { From 8a2336893a7ff610a6c4629dd567b85078730616 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Mon, 9 Nov 2015 14:56:36 -0800 Subject: [PATCH 57/88] [SPARK-6517][MLLIB] Implement the Algorithm of Hierarchical Clustering I implemented a hierarchical clustering algorithm again. This PR doesn't include examples, documentation and spark.ml APIs. I am going to send another PRs later. https://issues.apache.org/jira/browse/SPARK-6517 - This implementation based on a bi-sectiong K-means clustering. - It derives from the freeman-lab 's implementation - The basic idea is not changed from the previous version. (#2906) - However, It is 1000x faster than the previous version through parallel processing. Thank you for your great cooperation, RJ Nowling(rnowling), Jeremy Freeman(freeman-lab), Xiangrui Meng(mengxr) and Sean Owen(srowen). Author: Yu ISHIKAWA Author: Xiangrui Meng Author: Yu ISHIKAWA Closes #5267 from yu-iskw/new-hierarchical-clustering. --- .../mllib/clustering/BisectingKMeans.scala | 491 ++++++++++++++++++ .../clustering/BisectingKMeansModel.scala | 95 ++++ .../clustering/JavaBisectingKMeansSuite.java | 73 +++ .../clustering/BisectingKMeansSuite.scala | 182 +++++++ 4 files changed, 841 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala create mode 100644 mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala new file mode 100644 index 0000000000000..29a7aa0bb63f2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -0,0 +1,491 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" + * by Steinbach, Karypis, and Kumar, with modification to fit Spark. + * The algorithm starts from a single cluster that contains all points. + * Iteratively it finds divisible clusters on the bottom level and bisects each of them using + * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible. + * The bisecting steps of clusters on the same level are grouped together to increase parallelism. + * If bisecting all divisible clusters on the bottom level would result more than `k` leaf clusters, + * larger clusters get higher priority. + * + * @param k the desired number of leaf clusters (default: 4). The actual number could be smaller if + * there are no divisible leaf clusters. + * @param maxIterations the max number of k-means iterations to split clusters (default: 20) + * @param minDivisibleClusterSize the minimum number of points (if >= 1.0) or the minimum proportion + * of points (if < 1.0) of a divisible cluster (default: 1) + * @param seed a random seed (default: hash value of the class name) + * + * @see [[http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf + * Steinbach, Karypis, and Kumar, A comparison of document clustering techniques, + * KDD Workshop on Text Mining, 2000.]] + */ +@Since("1.6.0") +@Experimental +class BisectingKMeans private ( + private var k: Int, + private var maxIterations: Int, + private var minDivisibleClusterSize: Double, + private var seed: Long) extends Logging { + + import BisectingKMeans._ + + /** + * Constructs with the default configuration + */ + @Since("1.6.0") + def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##) + + /** + * Sets the desired number of leaf clusters (default: 4). + * The actual number could be smaller if there are no divisible leaf clusters. + */ + @Since("1.6.0") + def setK(k: Int): this.type = { + require(k > 0, s"k must be positive but got $k.") + this.k = k + this + } + + /** + * Gets the desired number of leaf clusters. + */ + @Since("1.6.0") + def getK: Int = this.k + + /** + * Sets the max number of k-means iterations to split clusters (default: 20). + */ + @Since("1.6.0") + def setMaxIterations(maxIterations: Int): this.type = { + require(maxIterations > 0, s"maxIterations must be positive but got $maxIterations.") + this.maxIterations = maxIterations + this + } + + /** + * Gets the max number of k-means iterations to split clusters. + */ + @Since("1.6.0") + def getMaxIterations: Int = this.maxIterations + + /** + * Sets the minimum number of points (if >= `1.0`) or the minimum proportion of points + * (if < `1.0`) of a divisible cluster (default: 1). + */ + @Since("1.6.0") + def setMinDivisibleClusterSize(minDivisibleClusterSize: Double): this.type = { + require(minDivisibleClusterSize > 0.0, + s"minDivisibleClusterSize must be positive but got $minDivisibleClusterSize.") + this.minDivisibleClusterSize = minDivisibleClusterSize + this + } + + /** + * Gets the minimum number of points (if >= `1.0`) or the minimum proportion of points + * (if < `1.0`) of a divisible cluster. + */ + @Since("1.6.0") + def getMinDivisibleClusterSize: Double = minDivisibleClusterSize + + /** + * Sets the random seed (default: hash value of the class name). + */ + @Since("1.6.0") + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + + /** + * Gets the random seed. + */ + @Since("1.6.0") + def getSeed: Long = this.seed + + /** + * Runs the bisecting k-means algorithm. + * @param input RDD of vectors + * @return model for the bisecting kmeans + */ + @Since("1.6.0") + def run(input: RDD[Vector]): BisectingKMeansModel = { + if (input.getStorageLevel == StorageLevel.NONE) { + logWarning(s"The input RDD ${input.id} is not directly cached, which may hurt performance if" + + " its parent RDDs are also not cached.") + } + val d = input.map(_.size).first() + logInfo(s"Feature dimension: $d.") + // Compute and cache vector norms for fast distance computation. + val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK) + val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) } + var assignments = vectors.map(v => (ROOT_INDEX, v)) + var activeClusters = summarize(d, assignments) + val rootSummary = activeClusters(ROOT_INDEX) + val n = rootSummary.size + logInfo(s"Number of points: $n.") + logInfo(s"Initial cost: ${rootSummary.cost}.") + val minSize = if (minDivisibleClusterSize >= 1.0) { + math.ceil(minDivisibleClusterSize).toLong + } else { + math.ceil(minDivisibleClusterSize * n).toLong + } + logInfo(s"The minimum number of points of a divisible cluster is $minSize.") + var inactiveClusters = mutable.Seq.empty[(Long, ClusterSummary)] + val random = new Random(seed) + var numLeafClustersNeeded = k - 1 + var level = 1 + while (activeClusters.nonEmpty && numLeafClustersNeeded > 0 && level < LEVEL_LIMIT) { + // Divisible clusters are sufficiently large and have non-trivial cost. + var divisibleClusters = activeClusters.filter { case (_, summary) => + (summary.size >= minSize) && (summary.cost > MLUtils.EPSILON * summary.size) + } + // If we don't need all divisible clusters, take the larger ones. + if (divisibleClusters.size > numLeafClustersNeeded) { + divisibleClusters = divisibleClusters.toSeq.sortBy { case (_, summary) => + -summary.size + }.take(numLeafClustersNeeded) + .toMap + } + if (divisibleClusters.nonEmpty) { + val divisibleIndices = divisibleClusters.keys.toSet + logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.") + var newClusterCenters = divisibleClusters.flatMap { case (index, summary) => + val (left, right) = splitCenter(summary.center, random) + Iterator((leftChildIndex(index), left), (rightChildIndex(index), right)) + }.map(identity) // workaround for a Scala bug (SI-7005) that produces a not serializable map + var newClusters: Map[Long, ClusterSummary] = null + var newAssignments: RDD[(Long, VectorWithNorm)] = null + for (iter <- 0 until maxIterations) { + newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters) + .filter { case (index, _) => + divisibleIndices.contains(parentIndex(index)) + } + newClusters = summarize(d, newAssignments) + newClusterCenters = newClusters.mapValues(_.center).map(identity) + } + // TODO: Unpersist old indices. + val indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys + .persist(StorageLevel.MEMORY_AND_DISK) + assignments = indices.zip(vectors) + inactiveClusters ++= activeClusters + activeClusters = newClusters + numLeafClustersNeeded -= divisibleClusters.size + } else { + logInfo(s"None active and divisible clusters left on level $level. Stop iterations.") + inactiveClusters ++= activeClusters + activeClusters = Map.empty + } + level += 1 + } + val clusters = activeClusters ++ inactiveClusters + val root = buildTree(clusters) + new BisectingKMeansModel(root) + } + + /** + * Java-friendly version of [[run(RDD[Vector])*]] + */ + def run(data: JavaRDD[Vector]): BisectingKMeansModel = run(data.rdd) +} + +private object BisectingKMeans extends Serializable { + + /** The index of the root node of a tree. */ + private val ROOT_INDEX: Long = 1 + + private val MAX_DIVISIBLE_CLUSTER_INDEX: Long = Long.MaxValue / 2 + + private val LEVEL_LIMIT = math.log10(Long.MaxValue) / math.log10(2) + + /** Returns the left child index of the given node index. */ + private def leftChildIndex(index: Long): Long = { + require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index.") + 2 * index + } + + /** Returns the right child index of the given node index. */ + private def rightChildIndex(index: Long): Long = { + require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index + 1.") + 2 * index + 1 + } + + /** Returns the parent index of the given node index, or 0 if the input is 1 (root). */ + private def parentIndex(index: Long): Long = { + index / 2 + } + + /** + * Summarizes data by each cluster as Map. + * @param d feature dimension + * @param assignments pairs of point and its cluster index + * @return a map from cluster indices to corresponding cluster summaries + */ + private def summarize( + d: Int, + assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = { + assignments.aggregateByKey(new ClusterSummaryAggregator(d))( + seqOp = (agg, v) => agg.add(v), + combOp = (agg1, agg2) => agg1.merge(agg2) + ).mapValues(_.summary) + .collect().toMap + } + + /** + * Cluster summary aggregator. + * @param d feature dimension + */ + private class ClusterSummaryAggregator(val d: Int) extends Serializable { + private var n: Long = 0L + private val sum: Vector = Vectors.zeros(d) + private var sumSq: Double = 0.0 + + /** Adds a point. */ + def add(v: VectorWithNorm): this.type = { + n += 1L + // TODO: use a numerically stable approach to estimate cost + sumSq += v.norm * v.norm + BLAS.axpy(1.0, v.vector, sum) + this + } + + /** Merges another aggregator. */ + def merge(other: ClusterSummaryAggregator): this.type = { + n += other.n + sumSq += other.sumSq + BLAS.axpy(1.0, other.sum, sum) + this + } + + /** Returns the summary. */ + def summary: ClusterSummary = { + val mean = sum.copy + if (n > 0L) { + BLAS.scal(1.0 / n, mean) + } + val center = new VectorWithNorm(mean) + val cost = math.max(sumSq - n * center.norm * center.norm, 0.0) + new ClusterSummary(n, center, cost) + } + } + + /** + * Bisects a cluster center. + * + * @param center current cluster center + * @param random a random number generator + * @return initial centers + */ + private def splitCenter( + center: VectorWithNorm, + random: Random): (VectorWithNorm, VectorWithNorm) = { + val d = center.vector.size + val norm = center.norm + val level = 1e-4 * norm + val noise = Vectors.dense(Array.fill(d)(random.nextDouble())) + val left = center.vector.copy + BLAS.axpy(-level, noise, left) + val right = center.vector.copy + BLAS.axpy(level, noise, right) + (new VectorWithNorm(left), new VectorWithNorm(right)) + } + + /** + * Updates assignments. + * @param assignments current assignments + * @param divisibleIndices divisible cluster indices + * @param newClusterCenters new cluster centers + * @return new assignments + */ + private def updateAssignments( + assignments: RDD[(Long, VectorWithNorm)], + divisibleIndices: Set[Long], + newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = { + assignments.map { case (index, v) => + if (divisibleIndices.contains(index)) { + val children = Seq(leftChildIndex(index), rightChildIndex(index)) + val selected = children.minBy { child => + KMeans.fastSquaredDistance(newClusterCenters(child), v) + } + (selected, v) + } else { + (index, v) + } + } + } + + /** + * Builds a clustering tree by re-indexing internal and leaf clusters. + * @param clusters a map from cluster indices to corresponding cluster summaries + * @return the root node of the clustering tree + */ + private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = { + var leafIndex = 0 + var internalIndex = -1 + + /** + * Builds a subtree from this given node index. + */ + def buildSubTree(rawIndex: Long): ClusteringTreeNode = { + val cluster = clusters(rawIndex) + val size = cluster.size + val center = cluster.center + val cost = cluster.cost + val isInternal = clusters.contains(leftChildIndex(rawIndex)) + if (isInternal) { + val index = internalIndex + internalIndex -= 1 + val leftIndex = leftChildIndex(rawIndex) + val rightIndex = rightChildIndex(rawIndex) + val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex => + KMeans.fastSquaredDistance(center, clusters(childIndex).center) + }.max) + val left = buildSubTree(leftIndex) + val right = buildSubTree(rightIndex) + new ClusteringTreeNode(index, size, center, cost, height, Array(left, right)) + } else { + val index = leafIndex + leafIndex += 1 + val height = 0.0 + new ClusteringTreeNode(index, size, center, cost, height, Array.empty) + } + } + + buildSubTree(ROOT_INDEX) + } + + /** + * Summary of a cluster. + * + * @param size the number of points within this cluster + * @param center the center of the points within this cluster + * @param cost the sum of squared distances to the center + */ + private case class ClusterSummary(size: Long, center: VectorWithNorm, cost: Double) +} + +/** + * Represents a node in a clustering tree. + * + * @param index node index, negative for internal nodes and non-negative for leaf nodes + * @param size size of the cluster + * @param centerWithNorm cluster center with norm + * @param cost cost of the cluster, i.e., the sum of squared distances to the center + * @param height height of the node in the dendrogram. Currently this is defined as the max distance + * from the center to the centers of the children's, but subject to change. + * @param children children nodes + */ +@Since("1.6.0") +@Experimental +class ClusteringTreeNode private[clustering] ( + val index: Int, + val size: Long, + private val centerWithNorm: VectorWithNorm, + val cost: Double, + val height: Double, + val children: Array[ClusteringTreeNode]) extends Serializable { + + /** Whether this is a leaf node. */ + val isLeaf: Boolean = children.isEmpty + + require((isLeaf && index >= 0) || (!isLeaf && index < 0)) + + /** Cluster center. */ + def center: Vector = centerWithNorm.vector + + /** Predicts the leaf cluster node index that the input point belongs to. */ + def predict(point: Vector): Int = { + val (index, _) = predict(new VectorWithNorm(point)) + index + } + + /** Returns the full prediction path from root to leaf. */ + def predictPath(point: Vector): Array[ClusteringTreeNode] = { + predictPath(new VectorWithNorm(point)).toArray + } + + /** Returns the full prediction path from root to leaf. */ + private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = { + if (isLeaf) { + this :: Nil + } else { + val selected = children.minBy { child => + KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm) + } + selected :: selected.predictPath(pointWithNorm) + } + } + + /** + * Computes the cost (squared distance to the predicted leaf cluster center) of the input point. + */ + def computeCost(point: Vector): Double = { + val (_, cost) = predict(new VectorWithNorm(point)) + cost + } + + /** + * Predicts the cluster index and the cost of the input point. + */ + private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = { + predict(pointWithNorm, KMeans.fastSquaredDistance(centerWithNorm, pointWithNorm)) + } + + /** + * Predicts the cluster index and the cost of the input point. + * @param pointWithNorm input point + * @param cost the cost to the current center + * @return (predicted leaf cluster index, cost) + */ + private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = { + if (isLeaf) { + (index, cost) + } else { + val (selectedChild, minCost) = children.map { child => + (child, KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm)) + }.minBy(_._2) + selectedChild.predict(pointWithNorm, minCost) + } + } + + /** + * Returns all leaf nodes from this node. + */ + def leafNodes: Array[ClusteringTreeNode] = { + if (isLeaf) { + Array(this) + } else { + children.flatMap(_.leafNodes) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala new file mode 100644 index 0000000000000..5015f1540d920 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.Logging +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD + +/** + * Clustering model produced by [[BisectingKMeans]]. + * The prediction is done level-by-level from the root node to a leaf node, and at each node among + * its children the closest to the input point is selected. + * + * @param root the root node of the clustering tree + */ +@Since("1.6.0") +@Experimental +class BisectingKMeansModel @Since("1.6.0") ( + @Since("1.6.0") val root: ClusteringTreeNode + ) extends Serializable with Logging { + + /** + * Leaf cluster centers. + */ + @Since("1.6.0") + def clusterCenters: Array[Vector] = root.leafNodes.map(_.center) + + /** + * Number of leaf clusters. + */ + lazy val k: Int = clusterCenters.length + + /** + * Predicts the index of the cluster that the input point belongs to. + */ + @Since("1.6.0") + def predict(point: Vector): Int = { + root.predict(point) + } + + /** + * Predicts the indices of the clusters that the input points belong to. + */ + @Since("1.6.0") + def predict(points: RDD[Vector]): RDD[Int] = { + points.map { p => root.predict(p) } + } + + /** + * Java-friendly version of [[predict(RDD[Vector])*]] + */ + @Since("1.6.0") + def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] = + predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]] + + /** + * Computes the squared distance between the input point and the cluster center it belongs to. + */ + @Since("1.6.0") + def computeCost(point: Vector): Double = { + root.computeCost(point) + } + + /** + * Computes the sum of squared distances between the input points and their corresponding cluster + * centers. + */ + @Since("1.6.0") + def computeCost(data: RDD[Vector]): Double = { + data.map(root.computeCost).sum() + } + + /** + * Java-friendly version of [[computeCost(RDD[Vector])*]]. + */ + @Since("1.6.0") + def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd) +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java new file mode 100644 index 0000000000000..a714620ff7e4b --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; + +public class JavaBisectingKMeansSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", this.getClass().getSimpleName()); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void twoDimensionalData() { + JavaRDD points = sc.parallelize(Lists.newArrayList( + Vectors.dense(4, -1), + Vectors.dense(4, 1), + Vectors.sparse(2, new int[] {0}, new double[] {1.0}) + ), 2); + + BisectingKMeans bkm = new BisectingKMeans() + .setK(4) + .setMaxIterations(2) + .setSeed(1L); + BisectingKMeansModel model = bkm.run(points); + Assert.assertEquals(3, model.k()); + Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12); + for (ClusteringTreeNode child: model.root().children()) { + double[] center = child.center().toArray(); + if (center[0] > 2) { + Assert.assertEquals(2, child.size()); + Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12); + } else { + Assert.assertEquals(1, child.size()); + Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12); + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala new file mode 100644 index 0000000000000..41b9d5c0d93bb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("default values") { + val bkm0 = new BisectingKMeans() + assert(bkm0.getK === 4) + assert(bkm0.getMaxIterations === 20) + assert(bkm0.getMinDivisibleClusterSize === 1.0) + val bkm1 = new BisectingKMeans() + assert(bkm0.getSeed === bkm1.getSeed, "The default seed should be constant.") + } + + test("setter/getter") { + val bkm = new BisectingKMeans() + + val k = 10 + assert(bkm.getK !== k) + assert(bkm.setK(k).getK === k) + val maxIter = 100 + assert(bkm.getMaxIterations !== maxIter) + assert(bkm.setMaxIterations(maxIter).getMaxIterations === maxIter) + val minSize = 2.0 + assert(bkm.getMinDivisibleClusterSize !== minSize) + assert(bkm.setMinDivisibleClusterSize(minSize).getMinDivisibleClusterSize === minSize) + val seed = 10L + assert(bkm.getSeed !== seed) + assert(bkm.setSeed(seed).getSeed === seed) + + intercept[IllegalArgumentException] { + bkm.setK(0) + } + intercept[IllegalArgumentException] { + bkm.setMaxIterations(0) + } + intercept[IllegalArgumentException] { + bkm.setMinDivisibleClusterSize(0.0) + } + } + + test("1D data") { + val points = Vectors.sparse(1, Array.empty, Array.empty) +: + (1 until 8).map(i => Vectors.dense(i)) + val data = sc.parallelize(points, 2) + val bkm = new BisectingKMeans() + .setK(4) + .setMaxIterations(1) + .setSeed(1L) + // The clusters should be + // (0, 1, 2, 3, 4, 5, 6, 7) + // - (0, 1, 2, 3) + // - (0, 1) + // - (2, 3) + // - (4, 5, 6, 7) + // - (4, 5) + // - (6, 7) + val model = bkm.run(data) + assert(model.k === 4) + // The total cost should be 8 * 0.5 * 0.5 = 2.0. + assert(model.computeCost(data) ~== 2.0 relTol 1e-12) + val predictions = data.map(v => (v(0), model.predict(v))).collectAsMap() + Range(0, 8, 2).foreach { i => + assert(predictions(i) === predictions(i + 1), + s"$i and ${i + 1} should belong to the same cluster.") + } + val root = model.root + assert(root.center(0) ~== 3.5 relTol 1e-12) + assert(root.height ~== 2.0 relTol 1e-12) + assert(root.children.length === 2) + assert(root.children(0).height ~== 1.0 relTol 1e-12) + assert(root.children(1).height ~== 1.0 relTol 1e-12) + } + + test("points are the same") { + val data = sc.parallelize(Seq.fill(8)(Vectors.dense(1.0, 1.0)), 2) + val bkm = new BisectingKMeans() + .setK(2) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 1) + } + + test("more desired clusters than points") { + val data = sc.parallelize(Seq.tabulate(4)(i => Vectors.dense(i)), 2) + val bkm = new BisectingKMeans() + .setK(8) + .setMaxIterations(2) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 4) + } + + test("min divisible cluster") { + val data = sc.parallelize( + Seq.tabulate(16)(i => Vectors.dense(i)) ++ Seq.tabulate(4)(i => Vectors.dense(-100.0 - i)), + 2) + val bkm = new BisectingKMeans() + .setK(4) + .setMinDivisibleClusterSize(10) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.predict(Vectors.dense(-100)) === model.predict(Vectors.dense(-97))) + assert(model.predict(Vectors.dense(7)) !== model.predict(Vectors.dense(8))) + + bkm.setMinDivisibleClusterSize(0.5) + val sameModel = bkm.run(data) + assert(sameModel.k === 3) + } + + test("larger clusters get selected first") { + val data = sc.parallelize( + Seq.tabulate(16)(i => Vectors.dense(i)) ++ Seq.tabulate(4)(i => Vectors.dense(-100.0 - i)), + 2) + val bkm = new BisectingKMeans() + .setK(3) + .setMaxIterations(1) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.predict(Vectors.dense(-100)) === model.predict(Vectors.dense(-97))) + assert(model.predict(Vectors.dense(7)) !== model.predict(Vectors.dense(8))) + } + + test("2D data") { + val points = Seq( + (11, 10), (9, 10), (10, 9), (10, 11), + (11, -10), (9, -10), (10, -9), (10, -11), + (0, 1), (0, -1) + ).map { case (x, y) => + if (x == 0) { + Vectors.sparse(2, Array(1), Array(y)) + } else { + Vectors.dense(x, y) + } + } + val data = sc.parallelize(points, 2) + val bkm = new BisectingKMeans() + .setK(3) + .setMaxIterations(4) + .setSeed(1L) + val model = bkm.run(data) + assert(model.k === 3) + assert(model.root.center ~== Vectors.dense(8, 0) relTol 1e-12) + model.root.leafNodes.foreach { node => + if (node.center(0) < 5) { + assert(node.size === 2) + assert(node.center ~== Vectors.dense(0, 0) relTol 1e-12) + } else if (node.center(1) > 0) { + assert(node.size === 4) + assert(node.center ~== Vectors.dense(10, 10) relTol 1e-12) + } else { + assert(node.size === 4) + assert(node.center ~== Vectors.dense(10, -10) relTol 1e-12) + } + } + } +} From fcb57e9c7323e24b8563800deb035f94f616474e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 9 Nov 2015 15:16:47 -0800 Subject: [PATCH 58/88] [SPARK-11564][SQL][FOLLOW-UP] improve java api for GroupedDataset created `MapGroupFunction`, `FlatMapGroupFunction`, `CoGroupFunction` Author: Wenchen Fan Closes #9564 from cloud-fan/map. --- .../api/java/function/CoGroupFunction.java | 29 +++++++++++++++ .../api/java/function/FlatMapFunction.java | 2 +- .../api/java/function/FlatMapFunction2.java | 2 +- .../java/function/FlatMapGroupFunction.java | 28 +++++++++++++++ .../api/java/function/MapGroupFunction.java | 28 +++++++++++++++ .../plans/logical/basicOperators.scala | 4 +-- .../org/apache/spark/sql/GroupedDataset.scala | 12 +++---- .../spark/sql/execution/basicOperators.scala | 2 +- .../apache/spark/sql/JavaDatasetSuite.java | 36 ++++++++++++------- 9 files changed, 118 insertions(+), 25 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java create mode 100644 core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java diff --git a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java new file mode 100644 index 0000000000000..279639af5d430 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that returns zero or more output records from each grouping key and its values from 2 + * Datasets. + */ +public interface CoGroupFunction extends Serializable { + Iterable call(K key, Iterator left, Iterator right) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java index 23f5fdd43631b..ef0d1824121ec 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java @@ -23,5 +23,5 @@ * A function that returns zero or more output records from each input record. */ public interface FlatMapFunction extends Serializable { - public Iterable call(T t) throws Exception; + Iterable call(T t) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java index c48e92f535ff5..14a98a38ef5ab 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java @@ -23,5 +23,5 @@ * A function that takes two inputs and returns zero or more output records. */ public interface FlatMapFunction2 extends Serializable { - public Iterable call(T1 t1, T2 t2) throws Exception; + Iterable call(T1 t1, T2 t2) throws Exception; } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java new file mode 100644 index 0000000000000..18a2d733ca70d --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that returns zero or more output records from each grouping key and its values. + */ +public interface FlatMapGroupFunction extends Serializable { + Iterable call(K key, Iterator values) throws Exception; +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java new file mode 100644 index 0000000000000..2935f9986a560 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for a map function used in GroupedDataset's map function. + */ +public interface MapGroupFunction extends Serializable { + R call(K key, Iterator values) throws Exception; +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e151ac04ede2a..d771088d69dea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -527,7 +527,7 @@ case class MapGroups[K, T, U]( /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], left: LogicalPlan, @@ -551,7 +551,7 @@ object CoGroup { * right children. */ case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], kEncoder: ExpressionEncoder[K], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 5c3f626545875..850315e281dfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -108,9 +108,7 @@ class GroupedDataset[K, T] private[sql]( MapGroups(f, groupingAttributes, logicalPlan)) } - def flatMap[U]( - f: JFunction2[K, JIterator[T], JIterator[U]], - encoder: Encoder[U]): Dataset[U] = { + def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder) } @@ -131,9 +129,7 @@ class GroupedDataset[K, T] private[sql]( MapGroups(func, groupingAttributes, logicalPlan)) } - def map[U]( - f: JFunction2[K, JIterator[T], U], - encoder: Encoder[U]): Dataset[U] = { + def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { map((key, data) => f.call(key, data.asJava))(encoder) } @@ -218,7 +214,7 @@ class GroupedDataset[K, T] private[sql]( */ def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( - f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = { + f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { implicit def uEnc: Encoder[U] = other.tEncoder new Dataset[R]( sqlContext, @@ -232,7 +228,7 @@ class GroupedDataset[K, T] private[sql]( def cogroup[U, R]( other: GroupedDataset[K, U], - f: JFunction3[K, JIterator[T], JIterator[U], JIterator[R]], + f: CoGroupFunction[K, T, U, R], encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 2593b16b1c8d7..145de0db9edaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -391,7 +391,7 @@ case class MapGroups[K, T, U]( * The result of this function is encoded and flattened before being output. */ case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], kEncoder: ExpressionEncoder[K], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 0f90de774dd3e..312cf33e4c2d4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -29,7 +29,6 @@ import org.apache.spark.Accumulator; import org.apache.spark.SparkContext; import org.apache.spark.api.java.function.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.catalyst.encoders.Encoder; import org.apache.spark.sql.catalyst.encoders.Encoder$; @@ -170,20 +169,33 @@ public Integer call(String v) throws Exception { } }, e.INT()); - Dataset mapped = grouped.map( - new Function2, String>() { + Dataset mapped = grouped.map(new MapGroupFunction() { + @Override + public String call(Integer key, Iterator values) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return sb.toString(); + } + }, e.STRING()); + + Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + + Dataset flatMapped = grouped.flatMap( + new FlatMapGroupFunction() { @Override - public String call(Integer key, Iterator data) throws Exception { + public Iterable call(Integer key, Iterator values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); - while (data.hasNext()) { - sb.append(data.next()); + while (values.hasNext()) { + sb.append(values.next()); } - return sb.toString(); + return Collections.singletonList(sb.toString()); } }, e.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList()); List data2 = Arrays.asList(2, 6, 10); Dataset ds2 = context.createDataset(data2, e.INT()); @@ -196,9 +208,9 @@ public Integer call(Integer v) throws Exception { Dataset cogrouped = grouped.cogroup( grouped2, - new Function3, Iterator, Iterator>() { + new CoGroupFunction() { @Override - public Iterator call( + public Iterable call( Integer key, Iterator left, Iterator right) throws Exception { @@ -210,7 +222,7 @@ public Iterator call( while (right.hasNext()) { sb.append(right.next()); } - return Collections.singletonList(sb.toString()).iterator(); + return Collections.singletonList(sb.toString()); } }, e.STRING()); @@ -225,7 +237,7 @@ public void testGroupByColumn() { GroupedDataset grouped = ds.groupBy(length(col("value"))).asKey(e.INT()); Dataset mapped = grouped.map( - new Function2, String>() { + new MapGroupFunction() { @Override public String call(Integer key, Iterator data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); From 9565c246eadecf4836d247d0067f2200f061d25f Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Mon, 9 Nov 2015 15:20:50 -0800 Subject: [PATCH 59/88] [SPARK-9557][SQL] Refactor ParquetFilterSuite and remove old ParquetFilters code Actually this was resolved by https://github.com/apache/spark/pull/8275. But I found the JIRA issue for this is not marked as resolved since the PR above was made for another issue but the PR above resolved both. I commented that this is resolved by the PR above; however, I opened this PR as I would like to just add a little bit of corrections. In the previous PR, I refactored the test by not reducing just collecting filters; however, this would not test properly `And` filter (which is not given to the tests). I unintentionally changed this from the original way (before being refactored). In this PR, I just followed the original way to collect filters by reducing. I would like to close this if this PR is inappropriate and somebody would like this deal with it in the separate PR related with this. Author: hyukjinkwon Closes #9554 from HyukjinKwon/SPARK-9557. --- .../datasources/parquet/ParquetFilterSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index c24c9f025dad7..579dabf73318b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -54,12 +54,12 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex .select(output.map(e => Column(e)): _*) .where(Column(predicate)) - val analyzedPredicate = query.queryExecution.optimizedPlan.collect { + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { case PhysicalOperation(_, filters, LogicalRelation(_: ParquetRelation, _)) => filters - }.flatten - assert(analyzedPredicate.nonEmpty) + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined) - val selectedFilters = analyzedPredicate.flatMap(DataSourceStrategy.translateFilter) + val selectedFilters = maybeAnalyzedPredicate.flatMap(DataSourceStrategy.translateFilter) assert(selectedFilters.nonEmpty) selectedFilters.foreach { pred => From 2f38378856fb56bdd9be7ccedf56427e81701f4e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 9 Nov 2015 16:06:48 -0800 Subject: [PATCH 60/88] [SPARK-11360][DOC] Loss of nullability when writing parquet files This fix is to add one line to explain the current behavior of Spark SQL when writing Parquet files. All columns are forced to be nullable for compatibility reasons. Author: gatorsmile Closes #9314 from gatorsmile/lossNull. --- docs/sql-programming-guide.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index ccd26904329d3..6e02d6564b002 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -982,7 +982,8 @@ when a table is dropped. [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema -of the original data. +of the original data. When writing Parquet files, all columns are automatically converted to be nullable for +compatibility reasons. ### Loading Data Programmatically From 9c740a9ddf6344a03b4b45380eaf0cfc6e2299b5 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 9 Nov 2015 16:11:00 -0800 Subject: [PATCH 61/88] [SPARK-11578][SQL] User API for Typed Aggregation This PR adds a new interface for user-defined aggregations, that can be used in `DataFrame` and `Dataset` operations to take all of the elements of a group and reduce them to a single value. For example, the following aggregator extracts an `int` from a specific class and adds them up: ```scala case class Data(i: Int) val customSummer = new Aggregator[Data, Int, Int] { def prepare(d: Data) = d.i def reduce(l: Int, r: Int) = l + r def present(r: Int) = r }.toColumn() val ds: Dataset[Data] = ... val aggregated = ds.select(customSummer) ``` By using helper functions, users can make a generic `Aggregator` that works on any input type: ```scala /** An `Aggregator` that adds up any numeric type returned by the given function. */ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable { val numeric = implicitly[Numeric[N]] override def zero: N = numeric.zero override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) override def present(reduction: N): N = reduction } def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn ``` These aggregators can then be used alongside other built-in SQL aggregations. ```scala val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() ds .groupBy(_._1) .agg( sum(_._2), // The aggregator defined above. expr("sum(_2)").as[Int], // A built-in dynatically typed aggregation. count("*")) // A built-in statically typed aggregation. .collect() res0: ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L) ``` The current implementation focuses on integrating this into the typed API, but currently only supports running aggregations that return a single long value as explained in `TypedAggregateExpression`. This will be improved in a followup PR. Author: Michael Armbrust Closes #9555 from marmbrus/dataset-useragg. --- .../scala/org/apache/spark/sql/Column.scala | 11 +- .../scala/org/apache/spark/sql/Dataset.scala | 30 ++-- .../org/apache/spark/sql/GroupedDataset.scala | 51 ++++--- .../org/apache/spark/sql/SQLContext.scala | 1 - .../aggregate/TypedAggregateExpression.scala | 129 ++++++++++++++++++ .../spark/sql/expressions/Aggregator.scala | 81 +++++++++++ .../org/apache/spark/sql/functions.scala | 30 +++- .../apache/spark/sql/JavaDatasetSuite.java | 4 +- .../spark/sql/DatasetAggregatorSuite.scala | 65 +++++++++ 9 files changed, 360 insertions(+), 42 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index c32c93897ce0b..d26b6c3579205 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DataTypeParser import org.apache.spark.sql.types._ @@ -39,10 +39,13 @@ private[sql] object Column { } /** - * A [[Column]] where an [[Encoder]] has been given for the expected return type. + * A [[Column]] where an [[Encoder]] has been given for the expected input and return type. * @since 1.6.0 + * @tparam T The input type expected for this expression. Can be `Any` if the expression is type + * checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`). + * @tparam U The output type of this column. */ -class TypedColumn[T](expr: Expression)(implicit val encoder: Encoder[T]) extends Column(expr) +class TypedColumn[-T, U](expr: Expression, val encoder: Encoder[U]) extends Column(expr) /** * :: Experimental :: @@ -85,7 +88,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * results into the correct JVM types. * @since 1.6.0 */ - def as[T : Encoder]: TypedColumn[T] = new TypedColumn[T](expr) + def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U]) /** * Extracts a value or values from a complex type. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 959e0f5ba03e6..6d2968e2881f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -358,7 +358,7 @@ class Dataset[T] private[sql]( * }}} * @since 1.6.0 */ - def select[U1: Encoder](c1: TypedColumn[U1]): Dataset[U1] = { + def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan)) } @@ -367,7 +367,7 @@ class Dataset[T] private[sql]( * code reuse, we do this without the help of the type system and then use helper functions * that cast appropriately for the user facing interface. */ - protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = { + protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } val unresolvedPlan = Project(aliases, logicalPlan) val execution = new QueryExecution(sqlContext, unresolvedPlan) @@ -385,7 +385,7 @@ class Dataset[T] private[sql]( * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. * @since 1.6.0 */ - def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = + def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] /** @@ -393,9 +393,9 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1, U2, U3]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] = selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] /** @@ -403,10 +403,10 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1, U2, U3, U4]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3], - c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] = selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] /** @@ -414,11 +414,11 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1, U2, U3, U4, U5]( - c1: TypedColumn[U1], - c2: TypedColumn[U2], - c3: TypedColumn[U3], - c4: TypedColumn[U4], - c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = + c1: TypedColumn[T, U1], + c2: TypedColumn[T, U2], + c3: TypedColumn[T, U3], + c4: TypedColumn[T, U4], + c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] /* **************** * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 850315e281dfe..db61499229284 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.util.{Iterator => JIterator} + import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental @@ -26,8 +27,10 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttrib import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.QueryExecution + /** * :: Experimental :: * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not @@ -143,7 +146,7 @@ class GroupedDataset[K, T] private[sql]( * that cast appropriately for the user facing interface. * TODO: does not handle aggrecations that return nonflat results, */ - protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = { + protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val aliases = (groupingAttributes ++ columns.map(_.expr)).map { case u: UnresolvedAttribute => UnresolvedAlias(u) case expr: NamedExpression => expr @@ -151,7 +154,15 @@ class GroupedDataset[K, T] private[sql]( } val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan) - val execution = new QueryExecution(sqlContext, unresolvedPlan) + + // Fill in the input encoders for any aggregators in the plan. + val withEncoders = unresolvedPlan transformAllExpressions { + case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => + ta.copy( + aEncoder = Some(tEnc.asInstanceOf[ExpressionEncoder[Any]]), + children = dataAttributes) + } + val execution = new QueryExecution(sqlContext, withEncoders) val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]) @@ -162,43 +173,47 @@ class GroupedDataset[K, T] private[sql]( case (e, a) => e.unbind(a :: Nil).resolve(execution.analyzed.output) } - new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) + + new Dataset( + sqlContext, + execution, + ExpressionEncoder.tuple(encoders)) } /** * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key * and the result of computing this aggregation over all elements in the group. */ - def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] = - aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]] + def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] = + aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. */ - def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, A1, A2)] = - aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]] + def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] = + aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. */ - def agg[A1, A2, A3]( - col1: TypedColumn[A1], - col2: TypedColumn[A2], - col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] = - aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]] + def agg[U1, U2, U3]( + col1: TypedColumn[T, U1], + col2: TypedColumn[T, U2], + col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] = + aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key * and the result of computing these aggregations over all elements in the group. */ - def agg[A1, A2, A3, A4]( - col1: TypedColumn[A1], - col2: TypedColumn[A2], - col3: TypedColumn[A3], - col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] = - aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, A4)]] + def agg[U1, U2, U3, U4]( + col1: TypedColumn[T, U1], + col2: TypedColumn[T, U2], + col3: TypedColumn[T, U3], + col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] = + aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] /** * Returns a [[Dataset]] that contains a tuple with each key and the number of items present diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 5598731af5fcc..1cf1e30f967cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -21,7 +21,6 @@ import java.beans.{BeanInfo, Introspector} import java.util.Properties import java.util.concurrent.atomic.AtomicReference - import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala new file mode 100644 index 0000000000000..24d8122b6222b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import scala.language.existentials + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StructType, DataType} + +object TypedAggregateExpression { + def apply[A, B : Encoder, C : Encoder]( + aggregator: Aggregator[A, B, C]): TypedAggregateExpression = { + new TypedAggregateExpression( + aggregator.asInstanceOf[Aggregator[Any, Any, Any]], + None, + encoderFor[B].asInstanceOf[ExpressionEncoder[Any]], + encoderFor[C].asInstanceOf[ExpressionEncoder[Any]], + Nil, + 0, + 0) + } +} + +/** + * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has + * the following limitations: + * - It assumes the aggregator reduces and returns a single column of type `long`. + * - It might only work when there is a single aggregator in the first column. + * - It assumes the aggregator has a zero, `0`. + */ +case class TypedAggregateExpression( + aggregator: Aggregator[Any, Any, Any], + aEncoder: Option[ExpressionEncoder[Any]], + bEncoder: ExpressionEncoder[Any], + cEncoder: ExpressionEncoder[Any], + children: Seq[Expression], + mutableAggBufferOffset: Int, + inputAggBufferOffset: Int) + extends ImperativeAggregate with Logging { + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override def nullable: Boolean = true + + // TODO: this assumes flat results... + override def dataType: DataType = cEncoder.schema.head.dataType + + override def deterministic: Boolean = true + + override lazy val resolved: Boolean = aEncoder.isDefined + + override lazy val inputTypes: Seq[DataType] = + aEncoder.map(_.schema.map(_.dataType)).getOrElse(Nil) + + override val aggBufferSchema: StructType = bEncoder.schema + + override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + + // Note: although this simply copies aggBufferAttributes, this common code can not be placed + // in the superclass because that will lead to initialization ordering issues. + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + + lazy val inputAttributes = aEncoder.get.schema.toAttributes + lazy val inputMapping = AttributeMap(inputAttributes.zip(children)) + lazy val boundA = + aEncoder.get.copy(constructExpression = aEncoder.get.constructExpression transform { + case a: AttributeReference => inputMapping(a) + }) + + // TODO: this probably only works when we are in the first column. + val bAttributes = bEncoder.schema.toAttributes + lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) + + override def initialize(buffer: MutableRow): Unit = { + // TODO: We need to either force Aggregator to have a zero or we need to eliminate the need for + // this in execution. + buffer.setInt(mutableAggBufferOffset, aggregator.zero.asInstanceOf[Int]) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + val inputA = boundA.fromRow(input) + val currentB = boundB.fromRow(buffer) + val merged = aggregator.reduce(currentB, inputA) + val returned = boundB.toRow(merged) + buffer.setInt(mutableAggBufferOffset, returned.getInt(0)) + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + buffer1.setLong( + mutableAggBufferOffset, + buffer1.getLong(mutableAggBufferOffset) + buffer2.getLong(inputAggBufferOffset)) + } + + override def eval(buffer: InternalRow): Any = { + buffer.getInt(mutableAggBufferOffset) + } + + override def toString: String = { + s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})""" + } + + override def nodeName: String = aggregator.getClass.getSimpleName +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala new file mode 100644 index 0000000000000..0b3192a6da9d8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} + +/** + * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] + * operations to take all of the elements of a group and reduce them to a single value. + * + * For example, the following aggregator extracts an `int` from a specific class and adds them up: + * {{{ + * case class Data(i: Int) + * + * val customSummer = new Aggregator[Data, Int, Int] { + * def zero = 0 + * def reduce(b: Int, a: Data) = b + a.i + * def present(r: Int) = r + * }.toColumn() + * + * val ds: Dataset[Data] + * val aggregated = ds.select(customSummer) + * }}} + * + * Based loosely on Aggregator from Algebird: https://github.com/twitter/algebird + * + * @tparam A The input type for the aggregation. + * @tparam B The type of the intermediate value of the reduction. + * @tparam C The type of the final result. + */ +abstract class Aggregator[-A, B, C] { + + /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */ + def zero: B + + /** + * Combine two values to produce a new value. For performance, the function may modify `b` and + * return it instead of constructing new object for b. + */ + def reduce(b: B, a: A): B + + /** + * Transform the output of the reduction. + */ + def present(reduction: B): C + + /** + * Returns this `Aggregator` as a [[TypedColumn]] that can be used in [[Dataset]] or [[DataFrame]] + * operations. + */ + def toColumn( + implicit bEncoder: Encoder[B], + cEncoder: Encoder[C]): TypedColumn[A, C] = { + val expr = + new AggregateExpression2( + TypedAggregateExpression(this), + Complete, + false) + + new TypedColumn[A, C](expr, encoderFor[C]) + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 3f0b24b68b816..6d56542ee0875 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql + + import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} import scala.util.Try @@ -24,11 +26,32 @@ import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint import org.apache.spark.sql.types._ import org.apache.spark.util.Utils +/** + * Ensures that java functions signatures for methods that now return a [[TypedColumn]] still have + * legacy equivalents in bytecode. This compatibility is done by forcing the compiler to generate + * "bridge" methods due to the use of covariant return types. + * + * {{{ + * In LegacyFunctions: + * public abstract org.apache.spark.sql.Column avg(java.lang.String); + * + * In functions: + * public static org.apache.spark.sql.TypedColumn avg(...); + * }}} + * + * This allows us to use the same functions both in typed [[Dataset]] operations and untyped + * [[DataFrame]] operations when the return type for a given function is statically known. + */ +private[sql] abstract class LegacyFunctions { + def count(columnName: String): Column +} + /** * :: Experimental :: * Functions available for [[DataFrame]]. @@ -48,11 +71,14 @@ import org.apache.spark.util.Utils */ @Experimental // scalastyle:off -object functions { +object functions extends LegacyFunctions { // scalastyle:on private def withExpr(expr: Expression): Column = Column(expr) + private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) + + /** * Returns a [[Column]] based on the given column name. * @@ -234,7 +260,7 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def count(columnName: String): Column = count(Column(columnName)) + def count(columnName: String): TypedColumn[Any, Long] = count(Column(columnName)).as[Long] /** * Aggregate function: returns the number of distinct items in a group. diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 312cf33e4c2d4..2da63d1b96706 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -258,8 +258,8 @@ public void testSelect() { Dataset ds = context.createDataset(data, e.INT()); Dataset> selected = ds.select( - expr("value + 1").as(e.INT()), - col("value").cast("string").as(e.STRING())); + expr("value + 1"), + col("value").cast("string")).as(e.tuple(e.INT(), e.STRING())); Assert.assertEquals( Arrays.asList(tuple2(3, "2"), tuple2(7, "6")), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala new file mode 100644 index 0000000000000..340470c096b87 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.functions._ + +import scala.language.postfixOps + +import org.apache.spark.sql.test.SharedSQLContext + +import org.apache.spark.sql.expressions.Aggregator + +/** An `Aggregator` that adds up any numeric type returned by the given function. */ +class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable { + val numeric = implicitly[Numeric[N]] + + override def zero: N = numeric.zero + + override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) + + override def present(reduction: N): N = reduction +} + +class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = + new SumOf(f).toColumn + + test("typed aggregation: TypedAggregator") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg(sum(_._2)), + ("a", 30), ("b", 3), ("c", 1)) + } + + test("typed aggregation: TypedAggregator, expr, expr") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + sum(_._2), + expr("sum(_2)").as[Int], + count("*")), + ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) + } +} From 675c7e723cadff588405c23826a00686587728b8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 9 Nov 2015 16:22:15 -0800 Subject: [PATCH 62/88] [SPARK-11564][SQL] Fix documentation for DataFrame.take/collect Author: Reynold Xin Closes #9557 from rxin/SPARK-11564-1. --- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 8ab958adadcca..d25807cf8d09c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1479,8 +1479,8 @@ class DataFrame private[sql]( /** * Returns the first `n` rows in the [[DataFrame]]. * - * Running take requires moving data into the application's driver process, and doing so on a - * very large dataset can crash the driver process with OutOfMemoryError. + * Running take requires moving data into the application's driver process, and doing so with + * a very large `n` can crash the driver process with OutOfMemoryError. * * @group action * @since 1.3.0 @@ -1501,8 +1501,8 @@ class DataFrame private[sql]( /** * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. * - * Running take requires moving data into the application's driver process, and doing so with - * a very large `n` can crash the driver process with OutOfMemoryError. + * Running collect requires moving all the data into the application's driver process, and + * doing so on a very large dataset can crash the driver process with OutOfMemoryError. * * For Java API, use [[collectAsList]]. * From 7dc9d8dba6c4bc655896b137062d896dec4ef64a Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Mon, 9 Nov 2015 16:25:29 -0800 Subject: [PATCH 63/88] [SPARK-11610][MLLIB][PYTHON][DOCS] Make the docs of LDAModel.describeTopics in Python more specific cc jkbradley Author: Yu ISHIKAWA Closes #9577 from yu-iskw/SPARK-11610. --- python/pyspark/mllib/clustering.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 12081f8c69075..1fa061dc2da99 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -734,6 +734,12 @@ def describeTopics(self, maxTermsPerTopic=None): """Return the topics described by weighted terms. WARNING: If vocabSize and k are large, this can return a large object! + + :param maxTermsPerTopic: Maximum number of terms to collect for each topic. + (default: vocabulary size) + :return: Array over topics. Each topic is represented as a pair of matching arrays: + (term indices, term weights in topic). + Each topic's terms are sorted in order of decreasing weight. """ if maxTermsPerTopic is None: topics = self.call("describeTopics") From 61f9c8711c79f35d67b0456155866da316b131d9 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Mon, 9 Nov 2015 16:55:23 -0800 Subject: [PATCH 64/88] [SPARK-11069][ML] Add RegexTokenizer option to convert to lowercase jira: https://issues.apache.org/jira/browse/SPARK-11069 quotes from jira: Tokenizer converts strings to lowercase automatically, but RegexTokenizer does not. It would be nice to add an option to RegexTokenizer to convert to lowercase. Proposal: call the Boolean Param "toLowercase" set default to false (so behavior does not change) Actually sklearn converts to lowercase before tokenizing too Author: Yuhao Yang Closes #9092 from hhbyyh/tokenLower. --- .../apache/spark/ml/feature/Tokenizer.scala | 19 ++++++++++++++-- .../spark/ml/feature/JavaTokenizerSuite.java | 1 + .../spark/ml/feature/TokenizerSuite.scala | 22 ++++++++++++++----- 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 248288ca73e99..1b82b40caac18 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -100,10 +100,25 @@ class RegexTokenizer(override val uid: String) /** @group getParam */ def getPattern: String = $(pattern) - setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+") + /** + * Indicates whether to convert all characters to lowercase before tokenizing. + * Default: true + * @group param + */ + final val toLowercase: BooleanParam = new BooleanParam(this, "toLowercase", + "whether to convert all characters to lowercase before tokenizing.") + + /** @group setParam */ + def setToLowercase(value: Boolean): this.type = set(toLowercase, value) + + /** @group getParam */ + def getToLowercase: Boolean = $(toLowercase) + + setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true) - override protected def createTransformFunc: String => Seq[String] = { str => + override protected def createTransformFunc: String => Seq[String] = { originStr => val re = $(pattern).r + val str = if ($(toLowercase)) originStr.toLowerCase() else originStr val tokens = if ($(gaps)) re.split(str).toSeq else re.findAllIn(str).toSeq val minLength = $(minTokenLength) tokens.filter(_.length >= minLength) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java index 02309ce63219a..c407d98f1b795 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -53,6 +53,7 @@ public void regexTokenizer() { .setOutputCol("tokens") .setPattern("\\s") .setGaps(true) + .setToLowercase(false) .setMinTokenLength(3); diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index e5fd21c3f6fca..a02992a2407b3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -48,13 +48,13 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("rawText") .setOutputCol("tokens") val dataset0 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization", ".")), - TokenizerTestData("Te,st. punct", Array("Te", ",", "st", ".", "punct")) + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")), + TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct")) )) testRegexTokenizer(tokenizer0, dataset0) val dataset1 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization")), + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")), TokenizerTestData("Te,st. punct", Array("punct")) )) tokenizer0.setMinTokenLength(3) @@ -64,11 +64,23 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { .setInputCol("rawText") .setOutputCol("tokens") val dataset2 = sqlContext.createDataFrame(Seq( - TokenizerTestData("Test for tokenization.", Array("Test", "for", "tokenization.")), - TokenizerTestData("Te,st. punct", Array("Te,st.", "punct")) + TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")), + TokenizerTestData("Te,st. punct", Array("te,st.", "punct")) )) testRegexTokenizer(tokenizer2, dataset2) } + + test("RegexTokenizer with toLowercase false") { + val tokenizer = new RegexTokenizer() + .setInputCol("rawText") + .setOutputCol("tokens") + .setToLowercase(false) + val dataset = sqlContext.createDataFrame(Seq( + TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")), + TokenizerTestData("java scala", Array("java", "scala")) + )) + testRegexTokenizer(tokenizer, dataset) + } } object RegexTokenizerSuite extends SparkFunSuite { From 26062d22607e1f9854bc2588ba22a4e0f8bba48c Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 9 Nov 2015 17:18:49 -0800 Subject: [PATCH 65/88] [SPARK-11198][STREAMING][KINESIS] Support de-aggregation of records during recovery While the KCL handles de-aggregation during the regular operation, during recovery we use the lower level api, and therefore need to de-aggregate the records. tdas Testing is an issue, we need protobuf magic to do the aggregated records. Maybe we could depend on KPL for tests? Author: Burak Yavuz Closes #9403 from brkyvz/kinesis-deaggregation. --- extras/kinesis-asl/pom.xml | 6 ++ .../kinesis/KinesisBackedBlockRDD.scala | 6 +- .../streaming/kinesis/KinesisReceiver.scala | 1 - .../kinesis/KinesisRecordProcessor.scala | 2 +- .../kinesis/KinesisBackedBlockRDDSuite.scala | 12 +++- .../kinesis/KinesisStreamSuite.scala | 17 +++--- .../streaming/kinesis/KinesisTestUtils.scala | 55 +++++++++++++++---- pom.xml | 2 + 8 files changed, 76 insertions(+), 25 deletions(-) rename extras/kinesis-asl/src/{main => test}/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala (80%) diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index ef72d97eae69d..519a920279c97 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -64,6 +64,12 @@ aws-java-sdk ${aws.java.sdk.version} + + com.amazonaws + amazon-kinesis-producer + ${aws.kinesis.producer.version} + test + org.mockito mockito-core diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 000897a4e7290..691c1790b207f 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -23,6 +23,7 @@ import scala.util.control.NonFatal import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord import com.amazonaws.services.kinesis.model._ import org.apache.spark._ @@ -210,7 +211,10 @@ class KinesisSequenceRangeIterator( s"getting records using shard iterator") { client.getRecords(getRecordsRequest) } - (getRecordsResult.getRecords.iterator().asScala, getRecordsResult.getNextShardIterator) + // De-aggregate records, if KPL was used in producing the records. The KCL automatically + // handles de-aggregation during regular operation. This code path is used during recovery + val recordIterator = UserRecord.deaggregate(getRecordsResult.getRecords) + (recordIterator.iterator().asScala, getRecordsResult.getNextShardIterator) } /** diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 50993f157cd95..97dbb918573a3 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -216,7 +216,6 @@ private[kinesis] class KinesisReceiver[T]( val metadata = SequenceNumberRange(streamName, shardId, records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber()) blockGenerator.addMultipleDataWithCallback(dataIterator, metadata) - } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index e381ffa0cbef4..b5b76cb92d866 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -80,7 +80,7 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w * more than once. */ logError(s"Exception: WorkerId $workerId encountered and exception while storing " + - " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) + s" or checkpointing a batch for workerId $workerId and shardId $shardId.", e) /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */ throw e diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 9f9e146a08d46..52c61dfb1c023 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -22,7 +22,8 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import org.apache.spark.{SparkConf, SparkContext, SparkException} -class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll { +abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) + extends KinesisFunSuite with BeforeAndAfterAll { private val testData = 1 to 8 @@ -37,13 +38,12 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll private var sc: SparkContext = null private var blockManager: BlockManager = null - override def beforeAll(): Unit = { runIfTestsEnabled("Prepare KinesisTestUtils") { testUtils = new KinesisTestUtils() testUtils.createStream() - shardIdToDataAndSeqNumbers = testUtils.pushData(testData) + shardIdToDataAndSeqNumbers = testUtils.pushData(testData, aggregate = aggregateTestData) require(shardIdToDataAndSeqNumbers.size > 1, "Need data to be sent to multiple shards") shardIds = shardIdToDataAndSeqNumbers.keySet.toSeq @@ -247,3 +247,9 @@ class KinesisBackedBlockRDDSuite extends KinesisFunSuite with BeforeAndAfterAll Array.tabulate(num) { i => new StreamBlockId(0, i) } } } + +class WithAggregationKinesisBackedBlockRDDSuite + extends KinesisBackedBlockRDDTests(aggregateTestData = true) + +class WithoutAggregationKinesisBackedBlockRDDSuite + extends KinesisBackedBlockRDDTests(aggregateTestData = false) diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index ba84e557dfcc2..dee30444d8cc6 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -39,7 +39,7 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.util.Utils import org.apache.spark.{SparkConf, SparkContext} -class KinesisStreamSuite extends KinesisFunSuite +abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFunSuite with Eventually with BeforeAndAfter with BeforeAndAfterAll { // This is the name that KCL will use to save metadata to DynamoDB @@ -182,13 +182,13 @@ class KinesisStreamSuite extends KinesisFunSuite val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] stream.map { bytes => new String(bytes).toInt }.foreachRDD { rdd => collected ++= rdd.collect() - logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + logInfo("Collected = " + collected.mkString(", ")) } ssc.start() val testData = 1 to 10 eventually(timeout(120 seconds), interval(10 second)) { - testUtils.pushData(testData) + testUtils.pushData(testData, aggregateTestData) assert(collected === testData.toSet, "\nData received does not match data sent") } ssc.stop(stopSparkContext = false) @@ -207,13 +207,13 @@ class KinesisStreamSuite extends KinesisFunSuite val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] stream.foreachRDD { rdd => collected ++= rdd.collect() - logInfo("Collected = " + rdd.collect().toSeq.mkString(", ")) + logInfo("Collected = " + collected.mkString(", ")) } ssc.start() val testData = 1 to 10 eventually(timeout(120 seconds), interval(10 second)) { - testUtils.pushData(testData) + testUtils.pushData(testData, aggregateTestData) val modData = testData.map(_ + 5) assert(collected === modData.toSet, "\nData received does not match data sent") } @@ -254,7 +254,7 @@ class KinesisStreamSuite extends KinesisFunSuite // If this times out because numBatchesWithData is empty, then its likely that foreachRDD // function failed with exceptions, and nothing got added to `collectedData` eventually(timeout(2 minutes), interval(1 seconds)) { - testUtils.pushData(1 to 5) + testUtils.pushData(1 to 5, aggregateTestData) assert(isCheckpointPresent && numBatchesWithData > 10) } ssc.stop(stopSparkContext = true) // stop the SparkContext so that the blocks are not reused @@ -285,5 +285,8 @@ class KinesisStreamSuite extends KinesisFunSuite } ssc.stop() } - } + +class WithAggregationKinesisStreamSuite extends KinesisStreamTests(aggregateTestData = true) + +class WithoutAggregationKinesisStreamSuite extends KinesisStreamTests(aggregateTestData = false) diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala similarity index 80% rename from extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala rename to extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 634bf94521079..7487aa1c12639 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -31,6 +31,8 @@ import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient import com.amazonaws.services.dynamodbv2.document.DynamoDB import com.amazonaws.services.kinesis.AmazonKinesisClient import com.amazonaws.services.kinesis.model._ +import com.amazonaws.services.kinesis.producer.{KinesisProducer, KinesisProducerConfiguration, UserRecordResult} +import com.google.common.util.concurrent.{FutureCallback, Futures} import org.apache.spark.Logging @@ -64,6 +66,16 @@ private[kinesis] class KinesisTestUtils extends Logging { new DynamoDB(dynamoDBClient) } + private lazy val kinesisProducer: KinesisProducer = { + val conf = new KinesisProducerConfiguration() + .setRecordMaxBufferedTime(1000) + .setMaxConnections(1) + .setRegion(regionName) + .setMetricsLevel("none") + + new KinesisProducer(conf) + } + def streamName: String = { require(streamCreated, "Stream not yet created, call createStream() to create one") _streamName @@ -90,22 +102,41 @@ private[kinesis] class KinesisTestUtils extends Logging { * Push data to Kinesis stream and return a map of * shardId -> seq of (data, seq number) pushed to corresponding shard */ - def pushData(testData: Seq[Int]): Map[String, Seq[(Int, String)]] = { + def pushData(testData: Seq[Int], aggregate: Boolean): Map[String, Seq[(Int, String)]] = { require(streamCreated, "Stream not yet created, call createStream() to create one") val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() testData.foreach { num => val str = num.toString - val putRecordRequest = new PutRecordRequest().withStreamName(streamName) - .withData(ByteBuffer.wrap(str.getBytes())) - .withPartitionKey(str) - - val putRecordResult = kinesisClient.putRecord(putRecordRequest) - val shardId = putRecordResult.getShardId - val seqNumber = putRecordResult.getSequenceNumber() - val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, - new ArrayBuffer[(Int, String)]()) - sentSeqNumbers += ((num, seqNumber)) + val data = ByteBuffer.wrap(str.getBytes()) + if (aggregate) { + val future = kinesisProducer.addUserRecord(streamName, str, data) + val kinesisCallBack = new FutureCallback[UserRecordResult]() { + override def onFailure(t: Throwable): Unit = {} // do nothing + + override def onSuccess(result: UserRecordResult): Unit = { + val shardId = result.getShardId + val seqNumber = result.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } + } + + Futures.addCallback(future, kinesisCallBack) + kinesisProducer.flushSync() // make sure we send all data before returning the map + } else { + val putRecordRequest = new PutRecordRequest().withStreamName(streamName) + .withData(data) + .withPartitionKey(str) + + val putRecordResult = kinesisClient.putRecord(putRecordRequest) + val shardId = putRecordResult.getShardId + val seqNumber = putRecordResult.getSequenceNumber() + val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, + new ArrayBuffer[(Int, String)]()) + sentSeqNumbers += ((num, seqNumber)) + } } logInfo(s"Pushed $testData:\n\t ${shardIdToSeqNumbers.mkString("\n\t")}") @@ -116,7 +147,7 @@ private[kinesis] class KinesisTestUtils extends Logging { * Expose a Python friendly API. */ def pushData(testData: java.util.List[Int]): Unit = { - pushData(testData.asScala) + pushData(testData.asScala, aggregate = false) } def deleteStream(): Unit = { diff --git a/pom.xml b/pom.xml index 4ed1c0c82dee6..fd8c773513881 100644 --- a/pom.xml +++ b/pom.xml @@ -154,6 +154,8 @@ 0.7.1 1.9.40 1.4.0 + + 0.10.1 4.3.2 From 0ce6f9b2d203ce67aeb4d3aedf19bbd997fe01b9 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 9 Nov 2015 17:35:12 -0800 Subject: [PATCH 66/88] [SPARK-11141][STREAMING] Batch ReceivedBlockTrackerLogEvents for WAL writes When using S3 as a directory for WALs, the writes take too long. The driver gets very easily bottlenecked when multiple receivers send AddBlock events to the ReceiverTracker. This PR adds batching of events in the ReceivedBlockTracker so that receivers don't get blocked by the driver for too long. cc zsxwing tdas Author: Burak Yavuz Closes #9143 from brkyvz/batch-wal-writes. --- .../scheduler/ReceivedBlockTracker.scala | 62 ++- .../streaming/scheduler/ReceiverTracker.scala | 25 +- .../streaming/util/BatchedWriteAheadLog.scala | 223 ++++++++ .../streaming/util/WriteAheadLogUtils.scala | 21 +- .../streaming/util/WriteAheadLogSuite.scala | 506 ++++++++++++------ .../util/WriteAheadLogUtilsSuite.scala | 122 +++++ 6 files changed, 767 insertions(+), 192 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index f2711d1355e60..500dc70c98506 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -22,12 +22,13 @@ import java.nio.ByteBuffer import scala.collection.JavaConverters._ import scala.collection.mutable import scala.language.implicitConversions +import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.streaming.Time -import org.apache.spark.streaming.util.{WriteAheadLog, WriteAheadLogUtils} +import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} import org.apache.spark.{Logging, SparkConf} @@ -41,7 +42,6 @@ private[streaming] case class BatchAllocationEvent(time: Time, allocatedBlocks: private[streaming] case class BatchCleanupEvent(times: Seq[Time]) extends ReceivedBlockTrackerLogEvent - /** Class representing the blocks of all the streams allocated to a batch */ private[streaming] case class AllocatedBlocks(streamIdToAllocatedBlocks: Map[Int, Seq[ReceivedBlockInfo]]) { @@ -82,15 +82,22 @@ private[streaming] class ReceivedBlockTracker( } /** Add received block. This event will get written to the write ahead log (if enabled). */ - def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = synchronized { + def addBlock(receivedBlockInfo: ReceivedBlockInfo): Boolean = { try { - writeToLog(BlockAdditionEvent(receivedBlockInfo)) - getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo - logDebug(s"Stream ${receivedBlockInfo.streamId} received " + - s"block ${receivedBlockInfo.blockStoreResult.blockId}") - true + val writeResult = writeToLog(BlockAdditionEvent(receivedBlockInfo)) + if (writeResult) { + synchronized { + getReceivedBlockQueue(receivedBlockInfo.streamId) += receivedBlockInfo + } + logDebug(s"Stream ${receivedBlockInfo.streamId} received " + + s"block ${receivedBlockInfo.blockStoreResult.blockId}") + } else { + logDebug(s"Failed to acknowledge stream ${receivedBlockInfo.streamId} receiving " + + s"block ${receivedBlockInfo.blockStoreResult.blockId} in the Write Ahead Log.") + } + writeResult } catch { - case e: Exception => + case NonFatal(e) => logError(s"Error adding block $receivedBlockInfo", e) false } @@ -106,10 +113,12 @@ private[streaming] class ReceivedBlockTracker( (streamId, getReceivedBlockQueue(streamId).dequeueAll(x => true)) }.toMap val allocatedBlocks = AllocatedBlocks(streamIdToBlocks) - writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks)) - timeToAllocatedBlocks(batchTime) = allocatedBlocks - lastAllocatedBatchTime = batchTime - allocatedBlocks + if (writeToLog(BatchAllocationEvent(batchTime, allocatedBlocks))) { + timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + lastAllocatedBatchTime = batchTime + } else { + logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery") + } } else { // This situation occurs when: // 1. WAL is ended with BatchAllocationEvent, but without BatchCleanupEvent, @@ -157,9 +166,12 @@ private[streaming] class ReceivedBlockTracker( require(cleanupThreshTime.milliseconds < clock.getTimeMillis()) val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq logInfo("Deleting batches " + timesToCleanup) - writeToLog(BatchCleanupEvent(timesToCleanup)) - timeToAllocatedBlocks --= timesToCleanup - writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) + if (writeToLog(BatchCleanupEvent(timesToCleanup))) { + timeToAllocatedBlocks --= timesToCleanup + writeAheadLogOption.foreach(_.clean(cleanupThreshTime.milliseconds, waitForCompletion)) + } else { + logWarning("Failed to acknowledge batch clean up in the Write Ahead Log.") + } } /** Stop the block tracker. */ @@ -185,8 +197,8 @@ private[streaming] class ReceivedBlockTracker( logTrace(s"Recovery: Inserting allocated batch for time $batchTime to " + s"${allocatedBlocks.streamIdToAllocatedBlocks}") streamIdToUnallocatedBlockQueues.values.foreach { _.clear() } - lastAllocatedBatchTime = batchTime timeToAllocatedBlocks.put(batchTime, allocatedBlocks) + lastAllocatedBatchTime = batchTime } // Cleanup the batch allocations @@ -213,12 +225,20 @@ private[streaming] class ReceivedBlockTracker( } /** Write an update to the tracker to the write ahead log */ - private def writeToLog(record: ReceivedBlockTrackerLogEvent) { + private def writeToLog(record: ReceivedBlockTrackerLogEvent): Boolean = { if (isWriteAheadLogEnabled) { - logDebug(s"Writing to log $record") - writeAheadLogOption.foreach { logManager => - logManager.write(ByteBuffer.wrap(Utils.serialize(record)), clock.getTimeMillis()) + logTrace(s"Writing record: $record") + try { + writeAheadLogOption.get.write(ByteBuffer.wrap(Utils.serialize(record)), + clock.getTimeMillis()) + true + } catch { + case NonFatal(e) => + logWarning(s"Exception thrown while writing record: $record to the WriteAheadLog.", e) + false } + } else { + true } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index b183d856f50c3..ea5d12b50fcc5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.scheduler import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.HashMap -import scala.concurrent.ExecutionContext +import scala.concurrent.{Future, ExecutionContext} import scala.language.existentials import scala.util.{Failure, Success} @@ -437,7 +437,12 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged private val submitJobThreadPool = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("submit-job-thead-pool")) + ThreadUtils.newDaemonCachedThreadPool("submit-job-thread-pool")) + + private val walBatchingThreadPool = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonCachedThreadPool("wal-batching-thread-pool")) + + @volatile private var active: Boolean = true override def receive: PartialFunction[Any, Unit] = { // Local messages @@ -488,7 +493,19 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false registerReceiver(streamId, typ, host, executorId, receiverEndpoint, context.senderAddress) context.reply(successful) case AddBlock(receivedBlockInfo) => - context.reply(addBlock(receivedBlockInfo)) + if (WriteAheadLogUtils.isBatchingEnabled(ssc.conf, isDriver = true)) { + walBatchingThreadPool.execute(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + if (active) { + context.reply(addBlock(receivedBlockInfo)) + } else { + throw new IllegalStateException("ReceiverTracker RpcEndpoint shut down.") + } + } + }) + } else { + context.reply(addBlock(receivedBlockInfo)) + } case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) context.reply(true) @@ -599,6 +616,8 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false override def onStop(): Unit = { submitJobThreadPool.shutdownNow() + active = false + walBatchingThreadPool.shutdown() } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala new file mode 100644 index 0000000000000..9727ed2ba1445 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.nio.ByteBuffer +import java.util.concurrent.LinkedBlockingQueue +import java.util.{Iterator => JIterator} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.{Await, Promise} +import scala.concurrent.duration._ +import scala.util.control.NonFatal + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.Utils + +/** + * A wrapper for a WriteAheadLog that batches records before writing data. Handles aggregation + * during writes, and de-aggregation in the `readAll` method. The end consumer has to handle + * de-aggregation after the `read` method. In addition, the `WriteAheadLogRecordHandle` returned + * after the write will contain the batch of records rather than individual records. + * + * When writing a batch of records, the `time` passed to the `wrappedLog` will be the timestamp + * of the latest record in the batch. This is very important in achieving correctness. Consider the + * following example: + * We receive records with timestamps 1, 3, 5, 7. We use "log-1" as the filename. Once we receive + * a clean up request for timestamp 3, we would clean up the file "log-1", and lose data regarding + * 5 and 7. + * + * This means the caller can assume the same write semantics as any other WriteAheadLog + * implementation despite the batching in the background - when the write() returns, the data is + * written to the WAL and is durable. To take advantage of the batching, the caller can write from + * multiple threads, each of which will stay blocked until the corresponding data has been written. + * + * All other methods of the WriteAheadLog interface will be passed on to the wrapped WriteAheadLog. + */ +private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: SparkConf) + extends WriteAheadLog with Logging { + + import BatchedWriteAheadLog._ + + private val walWriteQueue = new LinkedBlockingQueue[Record]() + + // Whether the writer thread is active + @volatile private var active: Boolean = true + private val buffer = new ArrayBuffer[Record]() + + private val batchedWriterThread = startBatchedWriterThread() + + /** + * Write a byte buffer to the log file. This method adds the byteBuffer to a queue and blocks + * until the record is properly written by the parent. + */ + override def write(byteBuffer: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { + val promise = Promise[WriteAheadLogRecordHandle]() + val putSuccessfully = synchronized { + if (active) { + walWriteQueue.offer(Record(byteBuffer, time, promise)) + true + } else { + false + } + } + if (putSuccessfully) { + Await.result(promise.future, WriteAheadLogUtils.getBatchingTimeout(conf).milliseconds) + } else { + throw new IllegalStateException("close() was called on BatchedWriteAheadLog before " + + s"write request with time $time could be fulfilled.") + } + } + + /** + * This method is not supported as the resulting ByteBuffer would actually require de-aggregation. + * This method is primarily used in testing, and to ensure that it is not used in production, + * we throw an UnsupportedOperationException. + */ + override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = { + throw new UnsupportedOperationException("read() is not supported for BatchedWriteAheadLog " + + "as the data may require de-aggregation.") + } + + /** + * Read all the existing logs from the log directory. The output of the wrapped WriteAheadLog + * will be de-aggregated. + */ + override def readAll(): JIterator[ByteBuffer] = { + wrappedLog.readAll().asScala.flatMap(deaggregate).asJava + } + + /** + * Delete the log files that are older than the threshold time. + * + * This method is handled by the parent WriteAheadLog. + */ + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { + wrappedLog.clean(threshTime, waitForCompletion) + } + + + /** + * Stop the batched writer thread, fulfill promises with failures and close the wrapped WAL. + */ + override def close(): Unit = { + logInfo(s"BatchedWriteAheadLog shutting down at time: ${System.currentTimeMillis()}.") + synchronized { + active = false + } + batchedWriterThread.interrupt() + batchedWriterThread.join() + while (!walWriteQueue.isEmpty) { + val Record(_, time, promise) = walWriteQueue.poll() + promise.failure(new IllegalStateException("close() was called on BatchedWriteAheadLog " + + s"before write request with time $time could be fulfilled.")) + } + wrappedLog.close() + } + + /** Start the actual log writer on a separate thread. */ + private def startBatchedWriterThread(): Thread = { + val thread = new Thread(new Runnable { + override def run(): Unit = { + while (active) { + try { + flushRecords() + } catch { + case NonFatal(e) => + logWarning("Encountered exception in Batched Writer Thread.", e) + } + } + logInfo("BatchedWriteAheadLog Writer thread exiting.") + } + }, "BatchedWriteAheadLog Writer") + thread.setDaemon(true) + thread.start() + thread + } + + /** Write all the records in the buffer to the write ahead log. */ + private def flushRecords(): Unit = { + try { + buffer.append(walWriteQueue.take()) + val numBatched = walWriteQueue.drainTo(buffer.asJava) + 1 + logDebug(s"Received $numBatched records from queue") + } catch { + case _: InterruptedException => + logWarning("BatchedWriteAheadLog Writer queue interrupted.") + } + try { + var segment: WriteAheadLogRecordHandle = null + if (buffer.length > 0) { + logDebug(s"Batched ${buffer.length} records for Write Ahead Log write") + // We take the latest record for the timestamp. Please refer to the class Javadoc for + // detailed explanation + val time = buffer.last.time + segment = wrappedLog.write(aggregate(buffer), time) + } + buffer.foreach(_.promise.success(segment)) + } catch { + case e: InterruptedException => + logWarning("BatchedWriteAheadLog Writer queue interrupted.", e) + buffer.foreach(_.promise.failure(e)) + case NonFatal(e) => + logWarning(s"BatchedWriteAheadLog Writer failed to write $buffer", e) + buffer.foreach(_.promise.failure(e)) + } finally { + buffer.clear() + } + } +} + +/** Static methods for aggregating and de-aggregating records. */ +private[util] object BatchedWriteAheadLog { + + /** + * Wrapper class for representing the records that we will write to the WriteAheadLog. Coupled + * with the timestamp for the write request of the record, and the promise that will block the + * write request, while a separate thread is actually performing the write. + */ + case class Record(data: ByteBuffer, time: Long, promise: Promise[WriteAheadLogRecordHandle]) + + /** Copies the byte array of a ByteBuffer. */ + private def getByteArray(buffer: ByteBuffer): Array[Byte] = { + val byteArray = new Array[Byte](buffer.remaining()) + buffer.get(byteArray) + byteArray + } + + /** Aggregate multiple serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. */ + def aggregate(records: Seq[Record]): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]]( + records.map(record => getByteArray(record.data)).toArray)) + } + + /** + * De-aggregate serialized ReceivedBlockTrackerLogEvents in a single ByteBuffer. + * A stream may not have used batching initially, but started using it after a restart. This + * method therefore needs to be backwards compatible. + */ + def deaggregate(buffer: ByteBuffer): Array[ByteBuffer] = { + try { + Utils.deserialize[Array[Array[Byte]]](getByteArray(buffer)).map(ByteBuffer.wrap) + } catch { + case _: ClassCastException => // users may restart a stream with batching enabled + Array(buffer) + } + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala index 0ea970e61b694..731a369fc92c0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala @@ -38,6 +38,8 @@ private[streaming] object WriteAheadLogUtils extends Logging { val DRIVER_WAL_ROLLING_INTERVAL_CONF_KEY = "spark.streaming.driver.writeAheadLog.rollingIntervalSecs" val DRIVER_WAL_MAX_FAILURES_CONF_KEY = "spark.streaming.driver.writeAheadLog.maxFailures" + val DRIVER_WAL_BATCHING_CONF_KEY = "spark.streaming.driver.writeAheadLog.allowBatching" + val DRIVER_WAL_BATCHING_TIMEOUT_CONF_KEY = "spark.streaming.driver.writeAheadLog.batchingTimeout" val DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY = "spark.streaming.driver.writeAheadLog.closeFileAfterWrite" @@ -64,6 +66,18 @@ private[streaming] object WriteAheadLogUtils extends Logging { } } + def isBatchingEnabled(conf: SparkConf, isDriver: Boolean): Boolean = { + isDriver && conf.getBoolean(DRIVER_WAL_BATCHING_CONF_KEY, defaultValue = false) + } + + /** + * How long we will wait for the wrappedLog in the BatchedWriteAheadLog to write the records + * before we fail the write attempt to unblock receivers. + */ + def getBatchingTimeout(conf: SparkConf): Long = { + conf.getLong(DRIVER_WAL_BATCHING_TIMEOUT_CONF_KEY, defaultValue = 5000) + } + def shouldCloseFileAfterWrite(conf: SparkConf, isDriver: Boolean): Boolean = { if (isDriver) { conf.getBoolean(DRIVER_WAL_CLOSE_AFTER_WRITE_CONF_KEY, defaultValue = false) @@ -115,7 +129,7 @@ private[streaming] object WriteAheadLogUtils extends Logging { } else { sparkConf.getOption(RECEIVER_WAL_CLASS_CONF_KEY) } - classNameOption.map { className => + val wal = classNameOption.map { className => try { instantiateClass( Utils.classForName(className).asInstanceOf[Class[_ <: WriteAheadLog]], sparkConf) @@ -128,6 +142,11 @@ private[streaming] object WriteAheadLogUtils extends Logging { getRollingIntervalSecs(sparkConf, isDriver), getMaxFailures(sparkConf, isDriver), shouldCloseFileAfterWrite(sparkConf, isDriver)) } + if (isBatchingEnabled(sparkConf, isDriver)) { + new BatchedWriteAheadLog(wal, sparkConf) + } else { + wal + } } /** Instantiate the class, either using single arg constructor or zero arg constructor */ diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index 93ae41a3d2ecd..e96f4c2a29347 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -18,31 +18,47 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer -import java.util +import java.util.concurrent.{ExecutionException, ThreadPoolExecutor} +import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import scala.concurrent._ import scala.concurrent.duration._ import scala.language.{implicitConversions, postfixOps} -import scala.reflect.ClassTag +import scala.util.{Failure, Success} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.mockito.Matchers.{eq => meq} +import org.mockito.Matchers._ +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ -import org.scalatest.BeforeAndAfter +import org.scalatest.{BeforeAndAfterEach, BeforeAndAfter} +import org.scalatest.mock.MockitoSugar -import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.util.{ThreadUtils, ManualClock, Utils} +import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} -class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { +/** Common tests for WriteAheadLogs that we would like to test with different configurations. */ +abstract class CommonWriteAheadLogTests( + allowBatching: Boolean, + closeFileAfterWrite: Boolean, + testTag: String = "") + extends SparkFunSuite with BeforeAndAfter { import WriteAheadLogSuite._ - val hadoopConf = new Configuration() - var tempDir: File = null - var testDir: String = null - var testFile: String = null - var writeAheadLog: FileBasedWriteAheadLog = null + protected val hadoopConf = new Configuration() + protected var tempDir: File = null + protected var testDir: String = null + protected var testFile: String = null + protected var writeAheadLog: WriteAheadLog = null + protected def testPrefix = if (testTag != "") testTag + " - " else testTag before { tempDir = Utils.createTempDir() @@ -58,49 +74,130 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { Utils.deleteRecursively(tempDir) } - test("WriteAheadLogUtils - log selection and creation") { - val logDir = Utils.createTempDir().getAbsolutePath() + test(testPrefix + "read all logs") { + // Write data manually for testing reading through WriteAheadLog + val writtenData = (1 to 10).map { i => + val data = generateRandomData() + val file = testDir + s"/log-$i-$i" + writeDataManually(data, file, allowBatching) + data + }.flatten - def assertDriverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = { - val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) - assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) - log + val logDirectoryPath = new Path(testDir) + val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) + assert(fileSystem.exists(logDirectoryPath) === true) + + // Read data using manager and verify + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData === writtenData) + } + + test(testPrefix + "write logs") { + // Write data with rotation using WriteAheadLog class + val dataToWrite = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite = closeFileAfterWrite, + allowBatching = allowBatching) + + // Read data manually to verify the written data + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + val writtenData = readAndDeserializeDataManually(logFiles, allowBatching) + assert(writtenData === dataToWrite) + } + + test(testPrefix + "read all logs after write") { + // Write data with manager, recover with new manager and verify + val dataToWrite = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite, allowBatching) + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(dataToWrite === readData) + } + + test(testPrefix + "clean old logs") { + logCleanUpTest(waitForCompletion = false) + } + + test(testPrefix + "clean old logs synchronously") { + logCleanUpTest(waitForCompletion = true) + } + + private def logCleanUpTest(waitForCompletion: Boolean): Unit = { + // Write data with manager, recover with new manager and verify + val manualClock = new ManualClock + val dataToWrite = generateRandomData() + writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, closeFileAfterWrite, + allowBatching, manualClock, closeLog = false) + val logFiles = getLogFilesInDirectory(testDir) + assert(logFiles.size > 1) + + writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion) + + if (waitForCompletion) { + assert(getLogFilesInDirectory(testDir).size < logFiles.size) + } else { + eventually(Eventually.timeout(1 second), interval(10 milliseconds)) { + assert(getLogFilesInDirectory(testDir).size < logFiles.size) + } } + } - def assertReceiverLogClass[T: ClassTag](conf: SparkConf): WriteAheadLog = { - val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf) - assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) - log + test(testPrefix + "handling file errors while reading rotating logs") { + // Generate a set of log files + val manualClock = new ManualClock + val dataToWrite1 = generateRandomData() + writeDataUsingWriteAheadLog(testDir, dataToWrite1, closeFileAfterWrite, allowBatching, + manualClock) + val logFiles1 = getLogFilesInDirectory(testDir) + assert(logFiles1.size > 1) + + + // Recover old files and generate a second set of log files + val dataToWrite2 = generateRandomData() + manualClock.advance(100000) + writeDataUsingWriteAheadLog(testDir, dataToWrite2, closeFileAfterWrite, allowBatching , + manualClock) + val logFiles2 = getLogFilesInDirectory(testDir) + assert(logFiles2.size > logFiles1.size) + + // Read the files and verify that all the written data can be read + val readData1 = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData1 === (dataToWrite1 ++ dataToWrite2)) + + // Corrupt the first set of files so that they are basically unreadable + logFiles1.foreach { f => + val raf = new FileOutputStream(f, true).getChannel() + raf.truncate(1) + raf.close() } - val emptyConf = new SparkConf() // no log configuration - assertDriverLogClass[FileBasedWriteAheadLog](emptyConf) - assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) - - // Verify setting driver WAL class - val conf1 = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", - classOf[MockWriteAheadLog0].getName()) - assertDriverLogClass[MockWriteAheadLog0](conf1) - assertReceiverLogClass[FileBasedWriteAheadLog](conf1) - - // Verify setting receiver WAL class - val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog0].getName()) - assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf) - assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) - - // Verify setting receiver WAL class with 1-arg constructor - val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog1].getName()) - assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2) - - // Verify failure setting receiver WAL class with 2-arg constructor - intercept[SparkException] { - val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", - classOf[MockWriteAheadLog2].getName()) - assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3) + // Verify that the corrupted files do not prevent reading of the second set of data + val readData = readDataUsingWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(readData === dataToWrite2) + } + + test(testPrefix + "do not create directories or files unless write") { + val nonexistentTempPath = File.createTempFile("test", "") + nonexistentTempPath.delete() + assert(!nonexistentTempPath.exists()) + + val writtenSegment = writeDataManually(generateRandomData(), testFile, allowBatching) + val wal = createWriteAheadLog(testDir, closeFileAfterWrite, allowBatching) + assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") + if (allowBatching) { + intercept[UnsupportedOperationException](wal.read(writtenSegment.head)) + } else { + wal.read(writtenSegment.head) } + assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") } +} + +class FileBasedWriteAheadLogSuite + extends CommonWriteAheadLogTests(false, false, "FileBasedWriteAheadLog") { + + import WriteAheadLogSuite._ test("FileBasedWriteAheadLogWriter - writing data") { val dataToWrite = generateRandomData() @@ -122,7 +219,7 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { test("FileBasedWriteAheadLogReader - sequentially reading data") { val writtenData = generateRandomData() - writeDataManually(writtenData, testFile) + writeDataManually(writtenData, testFile, allowBatching = false) val reader = new FileBasedWriteAheadLogReader(testFile, hadoopConf) val readData = reader.toSeq.map(byteBufferToString) assert(readData === writtenData) @@ -166,7 +263,7 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { test("FileBasedWriteAheadLogRandomReader - reading data using random reader") { // Write data manually for testing the random reader val writtenData = generateRandomData() - val segments = writeDataManually(writtenData, testFile) + val segments = writeDataManually(writtenData, testFile, allowBatching = false) // Get a random order of these segments and read them back val writtenDataAndSegments = writtenData.zip(segments).toSeq.permutations.take(10).flatten @@ -190,163 +287,212 @@ class WriteAheadLogSuite extends SparkFunSuite with BeforeAndAfter { } reader.close() } +} - test("FileBasedWriteAheadLog - write rotating logs") { - // Write data with rotation using WriteAheadLog class - val dataToWrite = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite) - - // Read data manually to verify the written data - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) - val writtenData = logFiles.flatMap { file => readDataManually(file)} - assert(writtenData === dataToWrite) - } +abstract class CloseFileAfterWriteTests(allowBatching: Boolean, testTag: String) + extends CommonWriteAheadLogTests(allowBatching, closeFileAfterWrite = true, testTag) { - test("FileBasedWriteAheadLog - close after write flag") { + import WriteAheadLogSuite._ + test(testPrefix + "close after write flag") { // Write data with rotation using WriteAheadLog class val numFiles = 3 val dataToWrite = Seq.tabulate(numFiles)(_.toString) // total advance time is less than 1000, therefore log shouldn't be rolled, but manually closed writeDataUsingWriteAheadLog(testDir, dataToWrite, closeLog = false, clockAdvanceTime = 100, - closeFileAfterWrite = true) + closeFileAfterWrite = true, allowBatching = allowBatching) // Read data manually to verify the written data val logFiles = getLogFilesInDirectory(testDir) assert(logFiles.size === numFiles) - val writtenData = logFiles.flatMap { file => readDataManually(file)} + val writtenData: Seq[String] = readAndDeserializeDataManually(logFiles, allowBatching) assert(writtenData === dataToWrite) } +} - test("FileBasedWriteAheadLog - read rotating logs") { - // Write data manually for testing reading through WriteAheadLog - val writtenData = (1 to 10).map { i => - val data = generateRandomData() - val file = testDir + s"/log-$i-$i" - writeDataManually(data, file) - data - }.flatten +class FileBasedWriteAheadLogWithFileCloseAfterWriteSuite + extends CloseFileAfterWriteTests(allowBatching = false, "FileBasedWriteAheadLog") - val logDirectoryPath = new Path(testDir) - val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - assert(fileSystem.exists(logDirectoryPath) === true) +class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( + allowBatching = true, + closeFileAfterWrite = false, + "BatchedWriteAheadLog") with MockitoSugar with BeforeAndAfterEach with Eventually { - // Read data using manager and verify - val readData = readDataUsingWriteAheadLog(testDir) - assert(readData === writtenData) - } + import BatchedWriteAheadLog._ + import WriteAheadLogSuite._ - test("FileBasedWriteAheadLog - recover past logs when creating new manager") { - // Write data with manager, recover with new manager and verify - val dataToWrite = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite) - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) - val readData = readDataUsingWriteAheadLog(testDir) - assert(dataToWrite === readData) + private var wal: WriteAheadLog = _ + private var walHandle: WriteAheadLogRecordHandle = _ + private var walBatchingThreadPool: ThreadPoolExecutor = _ + private var walBatchingExecutionContext: ExecutionContextExecutorService = _ + private val sparkConf = new SparkConf() + + override def beforeEach(): Unit = { + wal = mock[WriteAheadLog] + walHandle = mock[WriteAheadLogRecordHandle] + walBatchingThreadPool = ThreadUtils.newDaemonFixedThreadPool(8, "wal-test-thread-pool") + walBatchingExecutionContext = ExecutionContext.fromExecutorService(walBatchingThreadPool) } - test("FileBasedWriteAheadLog - clean old logs") { - logCleanUpTest(waitForCompletion = false) + override def afterEach(): Unit = { + if (walBatchingExecutionContext != null) { + walBatchingExecutionContext.shutdownNow() + } } - test("FileBasedWriteAheadLog - clean old logs synchronously") { - logCleanUpTest(waitForCompletion = true) - } + test("BatchedWriteAheadLog - serializing and deserializing batched records") { + val events = Seq( + BlockAdditionEvent(ReceivedBlockInfo(0, None, None, null)), + BatchAllocationEvent(null, null), + BatchCleanupEvent(Nil) + ) - private def logCleanUpTest(waitForCompletion: Boolean): Unit = { - // Write data with manager, recover with new manager and verify - val manualClock = new ManualClock - val dataToWrite = generateRandomData() - writeAheadLog = writeDataUsingWriteAheadLog(testDir, dataToWrite, manualClock, closeLog = false) - val logFiles = getLogFilesInDirectory(testDir) - assert(logFiles.size > 1) + val buffers = events.map(e => Record(ByteBuffer.wrap(Utils.serialize(e)), 0L, null)) + val batched = BatchedWriteAheadLog.aggregate(buffers) + val deaggregate = BatchedWriteAheadLog.deaggregate(batched).map(buffer => + Utils.deserialize[ReceivedBlockTrackerLogEvent](buffer.array())) - writeAheadLog.clean(manualClock.getTimeMillis() / 2, waitForCompletion) + assert(deaggregate.toSeq === events) + } - if (waitForCompletion) { - assert(getLogFilesInDirectory(testDir).size < logFiles.size) - } else { - eventually(timeout(1 second), interval(10 milliseconds)) { - assert(getLogFilesInDirectory(testDir).size < logFiles.size) - } + test("BatchedWriteAheadLog - failures in wrappedLog get bubbled up") { + when(wal.write(any[ByteBuffer], anyLong)).thenThrow(new RuntimeException("Hello!")) + // the BatchedWriteAheadLog should bubble up any exceptions that may have happened during writes + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) + + intercept[RuntimeException] { + val buffer = mock[ByteBuffer] + batchedWal.write(buffer, 2L) } } - test("FileBasedWriteAheadLog - handling file errors while reading rotating logs") { - // Generate a set of log files - val manualClock = new ManualClock - val dataToWrite1 = generateRandomData() - writeDataUsingWriteAheadLog(testDir, dataToWrite1, manualClock) - val logFiles1 = getLogFilesInDirectory(testDir) - assert(logFiles1.size > 1) + // we make the write requests in separate threads so that we don't block the test thread + private def promiseWriteEvent(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = { + val p = Promise[Unit]() + p.completeWith(Future { + val v = wal.write(event, time) + assert(v === walHandle) + }(walBatchingExecutionContext)) + p + } + /** + * In order to block the writes on the writer thread, we mock the write method, and block it + * for some time with a promise. + */ + private def writeBlockingPromise(wal: WriteAheadLog): Promise[Any] = { + // we would like to block the write so that we can queue requests + val promise = Promise[Any]() + when(wal.write(any[ByteBuffer], any[Long])).thenAnswer( + new Answer[WriteAheadLogRecordHandle] { + override def answer(invocation: InvocationOnMock): WriteAheadLogRecordHandle = { + Await.ready(promise.future, 4.seconds) + walHandle + } + } + ) + promise + } - // Recover old files and generate a second set of log files - val dataToWrite2 = generateRandomData() - manualClock.advance(100000) - writeDataUsingWriteAheadLog(testDir, dataToWrite2, manualClock) - val logFiles2 = getLogFilesInDirectory(testDir) - assert(logFiles2.size > logFiles1.size) + test("BatchedWriteAheadLog - name log with aggregated entries with the timestamp of last entry") { + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) + // block the write so that we can batch some records + val promise = writeBlockingPromise(wal) + + val event1 = "hello" + val event2 = "world" + val event3 = "this" + val event4 = "is" + val event5 = "doge" + + // The queue.take() immediately takes the 3, and there is nothing left in the queue at that + // moment. Then the promise blocks the writing of 3. The rest get queued. + promiseWriteEvent(batchedWal, event1, 3L) + // rest of the records will be batched while it takes 3 to get written + promiseWriteEvent(batchedWal, event2, 5L) + promiseWriteEvent(batchedWal, event3, 8L) + promiseWriteEvent(batchedWal, event4, 12L) + promiseWriteEvent(batchedWal, event5, 10L) + eventually(timeout(1 second)) { + assert(walBatchingThreadPool.getActiveCount === 5) + } + promise.success(true) - // Read the files and verify that all the written data can be read - val readData1 = readDataUsingWriteAheadLog(testDir) - assert(readData1 === (dataToWrite1 ++ dataToWrite2)) + val buffer1 = wrapArrayArrayByte(Array(event1)) + val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5)) - // Corrupt the first set of files so that they are basically unreadable - logFiles1.foreach { f => - val raf = new FileOutputStream(f, true).getChannel() - raf.truncate(1) - raf.close() + eventually(timeout(1 second)) { + verify(wal, times(1)).write(meq(buffer1), meq(3L)) + // the file name should be the timestamp of the last record, as events should be naturally + // in order of timestamp, and we need the last element. + verify(wal, times(1)).write(meq(buffer2), meq(10L)) } - - // Verify that the corrupted files do not prevent reading of the second set of data - val readData = readDataUsingWriteAheadLog(testDir) - assert(readData === dataToWrite2) } - test("FileBasedWriteAheadLog - do not create directories or files unless write") { - val nonexistentTempPath = File.createTempFile("test", "") - nonexistentTempPath.delete() - assert(!nonexistentTempPath.exists()) + test("BatchedWriteAheadLog - shutdown properly") { + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) + batchedWal.close() + verify(wal, times(1)).close() - val writtenSegment = writeDataManually(generateRandomData(), testFile) - val wal = new FileBasedWriteAheadLog(new SparkConf(), tempDir.getAbsolutePath, - new Configuration(), 1, 1, closeFileAfterWrite = false) - assert(!nonexistentTempPath.exists(), "Directory created just by creating log object") - wal.read(writtenSegment.head) - assert(!nonexistentTempPath.exists(), "Directory created just by attempting to read segment") + intercept[IllegalStateException](batchedWal.write(mock[ByteBuffer], 12L)) } -} -object WriteAheadLogSuite { + test("BatchedWriteAheadLog - fail everything in queue during shutdown") { + val batchedWal = new BatchedWriteAheadLog(wal, sparkConf) - class MockWriteAheadLog0() extends WriteAheadLog { - override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null } - override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null } - override def readAll(): util.Iterator[ByteBuffer] = { null } - override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { } - override def close(): Unit = { } - } + // block the write so that we can batch some records + writeBlockingPromise(wal) + + val event1 = ("hello", 3L) + val event2 = ("world", 5L) + val event3 = ("this", 8L) + val event4 = ("is", 9L) + val event5 = ("doge", 10L) + + // The queue.take() immediately takes the 3, and there is nothing left in the queue at that + // moment. Then the promise blocks the writing of 3. The rest get queued. + val writePromises = Seq(event1, event2, event3, event4, event5).map { event => + promiseWriteEvent(batchedWal, event._1, event._2) + } - class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0() + eventually(timeout(1 second)) { + assert(walBatchingThreadPool.getActiveCount === 5) + } + + batchedWal.close() + eventually(timeout(1 second)) { + assert(writePromises.forall(_.isCompleted)) + assert(writePromises.forall(_.future.value.get.isFailure)) // all should have failed + } + } +} - class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0() +class BatchedWriteAheadLogWithCloseFileAfterWriteSuite + extends CloseFileAfterWriteTests(allowBatching = true, "BatchedWriteAheadLog") +object WriteAheadLogSuite { private val hadoopConf = new Configuration() /** Write data to a file directly and return an array of the file segments written. */ - def writeDataManually(data: Seq[String], file: String): Seq[FileBasedWriteAheadLogSegment] = { + def writeDataManually( + data: Seq[String], + file: String, + allowBatching: Boolean): Seq[FileBasedWriteAheadLogSegment] = { val segments = new ArrayBuffer[FileBasedWriteAheadLogSegment]() val writer = HdfsUtils.getOutputStream(file, hadoopConf) - data.foreach { item => + def writeToStream(bytes: Array[Byte]): Unit = { val offset = writer.getPos - val bytes = Utils.serialize(item) writer.writeInt(bytes.size) writer.write(bytes) segments += FileBasedWriteAheadLogSegment(file, offset, bytes.size) } + if (allowBatching) { + writeToStream(wrapArrayArrayByte(data.toArray[String]).array()) + } else { + data.foreach { item => + writeToStream(Utils.serialize(item)) + } + } writer.close() segments } @@ -356,8 +502,7 @@ object WriteAheadLogSuite { */ def writeDataUsingWriter( filePath: String, - data: Seq[String] - ): Seq[FileBasedWriteAheadLogSegment] = { + data: Seq[String]): Seq[FileBasedWriteAheadLogSegment] = { val writer = new FileBasedWriteAheadLogWriter(filePath, hadoopConf) val segments = data.map { item => writer.write(item) @@ -370,13 +515,13 @@ object WriteAheadLogSuite { def writeDataUsingWriteAheadLog( logDirectory: String, data: Seq[String], + closeFileAfterWrite: Boolean, + allowBatching: Boolean, manualClock: ManualClock = new ManualClock, closeLog: Boolean = true, - clockAdvanceTime: Int = 500, - closeFileAfterWrite: Boolean = false): FileBasedWriteAheadLog = { + clockAdvanceTime: Int = 500): WriteAheadLog = { if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000) - val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1, - closeFileAfterWrite) + val wal = createWriteAheadLog(logDirectory, closeFileAfterWrite, allowBatching) // Ensure that 500 does not get sorted after 2000, so put a high base value. data.foreach { item => @@ -406,16 +551,16 @@ object WriteAheadLogSuite { } /** Read all the data from a log file directly and return the list of byte buffers. */ - def readDataManually(file: String): Seq[String] = { + def readDataManually[T](file: String): Seq[T] = { val reader = HdfsUtils.getInputStream(file, hadoopConf) - val buffer = new ArrayBuffer[String] + val buffer = new ArrayBuffer[T] try { while (true) { // Read till EOF is thrown val length = reader.readInt() val bytes = new Array[Byte](length) reader.read(bytes) - buffer += Utils.deserialize[String](bytes) + buffer += Utils.deserialize[T](bytes) } } catch { case ex: EOFException => @@ -434,15 +579,17 @@ object WriteAheadLogSuite { } /** Read all the data in the log file in a directory using the WriteAheadLog class. */ - def readDataUsingWriteAheadLog(logDirectory: String): Seq[String] = { - val wal = new FileBasedWriteAheadLog(new SparkConf(), logDirectory, hadoopConf, 1, 1, - closeFileAfterWrite = false) + def readDataUsingWriteAheadLog( + logDirectory: String, + closeFileAfterWrite: Boolean, + allowBatching: Boolean): Seq[String] = { + val wal = createWriteAheadLog(logDirectory, closeFileAfterWrite, allowBatching) val data = wal.readAll().asScala.map(byteBufferToString).toSeq wal.close() data } - /** Get the log files in a direction */ + /** Get the log files in a directory. */ def getLogFilesInDirectory(directory: String): Seq[String] = { val logDirectoryPath = new Path(directory) val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) @@ -458,10 +605,31 @@ object WriteAheadLogSuite { } } + def createWriteAheadLog( + logDirectory: String, + closeFileAfterWrite: Boolean, + allowBatching: Boolean): WriteAheadLog = { + val sparkConf = new SparkConf + val wal = new FileBasedWriteAheadLog(sparkConf, logDirectory, hadoopConf, 1, 1, + closeFileAfterWrite) + if (allowBatching) new BatchedWriteAheadLog(wal, sparkConf) else wal + } + def generateRandomData(): Seq[String] = { (1 to 100).map { _.toString } } + def readAndDeserializeDataManually(logFiles: Seq[String], allowBatching: Boolean): Seq[String] = { + if (allowBatching) { + logFiles.flatMap { file => + val data = readDataManually[Array[Array[Byte]]](file) + data.flatMap(byteArray => byteArray.map(Utils.deserialize[String])) + } + } else { + logFiles.flatMap { file => readDataManually[String](file)} + } + } + implicit def stringToByteBuffer(str: String): ByteBuffer = { ByteBuffer.wrap(Utils.serialize(str)) } @@ -469,4 +637,8 @@ object WriteAheadLogSuite { implicit def byteBufferToString(byteBuffer: ByteBuffer): String = { Utils.deserialize[String](byteBuffer.array) } + + def wrapArrayArrayByte[T](records: Array[T]): ByteBuffer = { + ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](records.map(Utils.serialize[T]))) + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala new file mode 100644 index 0000000000000..9152728191ea1 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.util + +import java.nio.ByteBuffer +import java.util + +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} +import org.apache.spark.util.Utils + +class WriteAheadLogUtilsSuite extends SparkFunSuite { + import WriteAheadLogUtilsSuite._ + + private val logDir = Utils.createTempDir().getAbsolutePath() + private val hadoopConf = new Configuration() + + def assertDriverLogClass[T <: WriteAheadLog: ClassTag]( + conf: SparkConf, + isBatched: Boolean = false): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForDriver(conf, logDir, hadoopConf) + if (isBatched) { + assert(log.isInstanceOf[BatchedWriteAheadLog]) + val parentLog = log.asInstanceOf[BatchedWriteAheadLog].wrappedLog + assert(parentLog.getClass === implicitly[ClassTag[T]].runtimeClass) + } else { + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + } + log + } + + def assertReceiverLogClass[T <: WriteAheadLog: ClassTag](conf: SparkConf): WriteAheadLog = { + val log = WriteAheadLogUtils.createLogForReceiver(conf, logDir, hadoopConf) + assert(log.getClass === implicitly[ClassTag[T]].runtimeClass) + log + } + + test("log selection and creation") { + + val emptyConf = new SparkConf() // no log configuration + assertDriverLogClass[FileBasedWriteAheadLog](emptyConf) + assertReceiverLogClass[FileBasedWriteAheadLog](emptyConf) + + // Verify setting driver WAL class + val driverWALConf = new SparkConf().set("spark.streaming.driver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[MockWriteAheadLog0](driverWALConf) + assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf) + + // Verify setting receiver WAL class + val receiverWALConf = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf) + assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) + + // Verify setting receiver WAL class with 1-arg constructor + val receiverWALConf2 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog1].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf2) + + // Verify failure setting receiver WAL class with 2-arg constructor + intercept[SparkException] { + val receiverWALConf3 = new SparkConf().set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog2].getName()) + assertReceiverLogClass[MockWriteAheadLog1](receiverWALConf3) + } + } + + test("wrap WriteAheadLog in BatchedWriteAheadLog when batching is enabled") { + def getBatchedSparkConf: SparkConf = + new SparkConf().set("spark.streaming.driver.writeAheadLog.allowBatching", "true") + + val justBatchingConf = getBatchedSparkConf + assertDriverLogClass[FileBasedWriteAheadLog](justBatchingConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](justBatchingConf) + + // Verify setting driver WAL class + val driverWALConf = getBatchedSparkConf.set("spark.streaming.driver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[MockWriteAheadLog0](driverWALConf, isBatched = true) + assertReceiverLogClass[FileBasedWriteAheadLog](driverWALConf) + + // Verify receivers are not wrapped + val receiverWALConf = getBatchedSparkConf.set("spark.streaming.receiver.writeAheadLog.class", + classOf[MockWriteAheadLog0].getName()) + assertDriverLogClass[FileBasedWriteAheadLog](receiverWALConf, isBatched = true) + assertReceiverLogClass[MockWriteAheadLog0](receiverWALConf) + } +} + +object WriteAheadLogUtilsSuite { + + class MockWriteAheadLog0() extends WriteAheadLog { + override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = { null } + override def read(handle: WriteAheadLogRecordHandle): ByteBuffer = { null } + override def readAll(): util.Iterator[ByteBuffer] = { null } + override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = { } + override def close(): Unit = { } + } + + class MockWriteAheadLog1(val conf: SparkConf) extends MockWriteAheadLog0() + + class MockWriteAheadLog2(val conf: SparkConf, x: Int) extends MockWriteAheadLog0() +} From 1f0f14efe35f986e338ee2cbc1ef2a9ce7395c00 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 9 Nov 2015 17:38:19 -0800 Subject: [PATCH 67/88] [SPARK-11462][STREAMING] Add JavaStreamingListener Currently, StreamingListener is not Java friendly because it exposes some Scala collections to Java users directly, such as Option, Map. This PR added a Java version of StreamingListener and a bunch of Java friendly classes for Java users. Author: zsxwing Author: Shixiong Zhu Closes #9420 from zsxwing/java-streaming-listener. --- .../api/java/JavaStreamingListener.scala | 168 ++++++++++ .../java/JavaStreamingListenerWrapper.scala | 122 ++++++++ .../JavaStreamingListenerAPISuite.java | 85 +++++ .../JavaStreamingListenerWrapperSuite.scala | 290 ++++++++++++++++++ 4 files changed, 665 insertions(+) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java create mode 100644 streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala new file mode 100644 index 0000000000000..c86c7101ff6d5 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import org.apache.spark.streaming.Time + +/** + * A listener interface for receiving information about an ongoing streaming computation. + */ +private[streaming] class JavaStreamingListener { + + /** Called when a receiver has been started */ + def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { } + + /** Called when a receiver has reported an error */ + def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { } + + /** Called when a receiver has been stopped */ + def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { } + + /** Called when a batch of jobs has been submitted for processing. */ + def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { } + + /** Called when processing of a batch of jobs has started. */ + def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { } + + /** Called when processing of a batch of jobs has completed. */ + def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { } + + /** Called when processing of a job of a batch has started. */ + def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { } + + /** Called when processing of a job of a batch has completed. */ + def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { } +} + +/** + * Base trait for events related to JavaStreamingListener + */ +private[streaming] sealed trait JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchSubmitted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchCompleted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerBatchStarted(val batchInfo: JavaBatchInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerOutputOperationStarted( + val outputOperationInfo: JavaOutputOperationInfo) extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerOutputOperationCompleted( + val outputOperationInfo: JavaOutputOperationInfo) extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverStarted(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverError(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +private[streaming] class JavaStreamingListenerReceiverStopped(val receiverInfo: JavaReceiverInfo) + extends JavaStreamingListenerEvent + +/** + * Class having information on batches. + * + * @param batchTime Time of the batch + * @param streamIdToInputInfo A map of input stream id to its input info + * @param submissionTime Clock time of when jobs of this batch was submitted to the streaming + * scheduler queue + * @param processingStartTime Clock time of when the first job of this batch started processing. + * `-1` means the batch has not yet started + * @param processingEndTime Clock time of when the last job of this batch finished processing. `-1` + * means the batch has not yet completed. + * @param schedulingDelay Time taken for the first job of this batch to start processing from the + * time this batch was submitted to the streaming scheduler. Essentially, it + * is `processingStartTime` - `submissionTime`. `-1` means the batch has not + * yet started + * @param processingDelay Time taken for the all jobs of this batch to finish processing from the + * time they started processing. Essentially, it is + * `processingEndTime` - `processingStartTime`. `-1` means the batch has not + * yet completed. + * @param totalDelay Time taken for all the jobs of this batch to finish processing from the time + * they were submitted. Essentially, it is `processingDelay` + `schedulingDelay`. + * `-1` means the batch has not yet completed. + * @param numRecords The number of recorders received by the receivers in this batch + * @param outputOperationInfos The output operations in this batch + */ +private[streaming] case class JavaBatchInfo( + batchTime: Time, + streamIdToInputInfo: java.util.Map[Int, JavaStreamInputInfo], + submissionTime: Long, + processingStartTime: Long, + processingEndTime: Long, + schedulingDelay: Long, + processingDelay: Long, + totalDelay: Long, + numRecords: Long, + outputOperationInfos: java.util.Map[Int, JavaOutputOperationInfo]) + +/** + * Track the information of input stream at specified batch time. + * + * @param inputStreamId the input stream id + * @param numRecords the number of records in a batch + * @param metadata metadata for this batch. It should contain at least one standard field named + * "Description" which maps to the content that will be shown in the UI. + * @param metadataDescription description of this input stream + */ +private[streaming] case class JavaStreamInputInfo( + inputStreamId: Int, + numRecords: Long, + metadata: java.util.Map[String, Any], + metadataDescription: String) + +/** + * Class having information about a receiver + */ +private[streaming] case class JavaReceiverInfo( + streamId: Int, + name: String, + active: Boolean, + location: String, + lastErrorMessage: String, + lastError: String, + lastErrorTime: Long) + +/** + * Class having information on output operations. + * + * @param batchTime Time of the batch + * @param id Id of this output operation. Different output operations have different ids in a batch. + * @param name The name of this output operation. + * @param description The description of this output operation. + * @param startTime Clock time of when the output operation started processing. `-1` means the + * output operation has not yet started + * @param endTime Clock time of when the output operation started processing. `-1` means the output + * operation has not yet completed + * @param failureReason Failure reason if this output operation fails. If the output operation is + * successful, this field is `null`. + */ +private[streaming] case class JavaOutputOperationInfo( + batchTime: Time, + id: Int, + name: String, + description: String, + startTime: Long, + endTime: Long, + failureReason: String) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala new file mode 100644 index 0000000000000..2c60b396a6616 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import scala.collection.JavaConverters._ + +import org.apache.spark.streaming.scheduler._ + +/** + * A wrapper to convert a [[JavaStreamingListener]] to a [[StreamingListener]]. + */ +private[streaming] class JavaStreamingListenerWrapper(javaStreamingListener: JavaStreamingListener) + extends StreamingListener { + + private def toJavaReceiverInfo(receiverInfo: ReceiverInfo): JavaReceiverInfo = { + JavaReceiverInfo( + receiverInfo.streamId, + receiverInfo.name, + receiverInfo.active, + receiverInfo.location, + receiverInfo.lastErrorMessage, + receiverInfo.lastError, + receiverInfo.lastErrorTime + ) + } + + private def toJavaStreamInputInfo(streamInputInfo: StreamInputInfo): JavaStreamInputInfo = { + JavaStreamInputInfo( + streamInputInfo.inputStreamId, + streamInputInfo.numRecords: Long, + streamInputInfo.metadata.asJava, + streamInputInfo.metadataDescription.orNull + ) + } + + private def toJavaOutputOperationInfo( + outputOperationInfo: OutputOperationInfo): JavaOutputOperationInfo = { + JavaOutputOperationInfo( + outputOperationInfo.batchTime, + outputOperationInfo.id, + outputOperationInfo.name, + outputOperationInfo.description: String, + outputOperationInfo.startTime.getOrElse(-1), + outputOperationInfo.endTime.getOrElse(-1), + outputOperationInfo.failureReason.orNull + ) + } + + private def toJavaBatchInfo(batchInfo: BatchInfo): JavaBatchInfo = { + JavaBatchInfo( + batchInfo.batchTime, + batchInfo.streamIdToInputInfo.mapValues(toJavaStreamInputInfo(_)).asJava, + batchInfo.submissionTime, + batchInfo.processingStartTime.getOrElse(-1), + batchInfo.processingEndTime.getOrElse(-1), + batchInfo.schedulingDelay.getOrElse(-1), + batchInfo.processingDelay.getOrElse(-1), + batchInfo.totalDelay.getOrElse(-1), + batchInfo.numRecords, + batchInfo.outputOperationInfos.mapValues(toJavaOutputOperationInfo(_)).asJava + ) + } + + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted): Unit = { + javaStreamingListener.onReceiverStarted( + new JavaStreamingListenerReceiverStarted(toJavaReceiverInfo(receiverStarted.receiverInfo))) + } + + override def onReceiverError(receiverError: StreamingListenerReceiverError): Unit = { + javaStreamingListener.onReceiverError( + new JavaStreamingListenerReceiverError(toJavaReceiverInfo(receiverError.receiverInfo))) + } + + override def onReceiverStopped(receiverStopped: StreamingListenerReceiverStopped): Unit = { + javaStreamingListener.onReceiverStopped( + new JavaStreamingListenerReceiverStopped(toJavaReceiverInfo(receiverStopped.receiverInfo))) + } + + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = { + javaStreamingListener.onBatchSubmitted( + new JavaStreamingListenerBatchSubmitted(toJavaBatchInfo(batchSubmitted.batchInfo))) + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = { + javaStreamingListener.onBatchStarted( + new JavaStreamingListenerBatchStarted(toJavaBatchInfo(batchStarted.batchInfo))) + } + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { + javaStreamingListener.onBatchCompleted( + new JavaStreamingListenerBatchCompleted(toJavaBatchInfo(batchCompleted.batchInfo))) + } + + override def onOutputOperationStarted( + outputOperationStarted: StreamingListenerOutputOperationStarted): Unit = { + javaStreamingListener.onOutputOperationStarted(new JavaStreamingListenerOutputOperationStarted( + toJavaOutputOperationInfo(outputOperationStarted.outputOperationInfo))) + } + + override def onOutputOperationCompleted( + outputOperationCompleted: StreamingListenerOutputOperationCompleted): Unit = { + javaStreamingListener.onOutputOperationCompleted( + new JavaStreamingListenerOutputOperationCompleted( + toJavaOutputOperationInfo(outputOperationCompleted.outputOperationInfo))) + } + +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java new file mode 100644 index 0000000000000..8cc285aa7fb34 --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.*; + +public class JavaStreamingListenerAPISuite extends JavaStreamingListener { + + @Override + public void onReceiverStarted(JavaStreamingListenerReceiverStarted receiverStarted) { + JavaReceiverInfo receiverInfo = receiverStarted.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onReceiverError(JavaStreamingListenerReceiverError receiverError) { + JavaReceiverInfo receiverInfo = receiverError.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onReceiverStopped(JavaStreamingListenerReceiverStopped receiverStopped) { + JavaReceiverInfo receiverInfo = receiverStopped.receiverInfo(); + receiverInfo.streamId(); + receiverInfo.name(); + receiverInfo.active(); + receiverInfo.location(); + receiverInfo.lastErrorMessage(); + receiverInfo.lastError(); + receiverInfo.lastErrorTime(); + } + + @Override + public void onBatchSubmitted(JavaStreamingListenerBatchSubmitted batchSubmitted) { + super.onBatchSubmitted(batchSubmitted); + } + + @Override + public void onBatchStarted(JavaStreamingListenerBatchStarted batchStarted) { + super.onBatchStarted(batchStarted); + } + + @Override + public void onBatchCompleted(JavaStreamingListenerBatchCompleted batchCompleted) { + super.onBatchCompleted(batchCompleted); + } + + @Override + public void onOutputOperationStarted(JavaStreamingListenerOutputOperationStarted outputOperationStarted) { + super.onOutputOperationStarted(outputOperationStarted); + } + + @Override + public void onOutputOperationCompleted(JavaStreamingListenerOutputOperationCompleted outputOperationCompleted) { + super.onOutputOperationCompleted(outputOperationCompleted); + } +} diff --git a/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala new file mode 100644 index 0000000000000..6d6d61e70cafc --- /dev/null +++ b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.streaming.Time +import org.apache.spark.streaming.scheduler._ + +class JavaStreamingListenerWrapperSuite extends SparkFunSuite { + + test("basic") { + val listener = new TestJavaStreamingListener() + val listenerWrapper = new JavaStreamingListenerWrapper(listener) + + val receiverStarted = StreamingListenerReceiverStarted(ReceiverInfo( + streamId = 2, + name = "test", + active = true, + location = "localhost" + )) + listenerWrapper.onReceiverStarted(receiverStarted) + assertReceiverInfo(listener.receiverStarted.receiverInfo, receiverStarted.receiverInfo) + + val receiverStopped = StreamingListenerReceiverStopped(ReceiverInfo( + streamId = 2, + name = "test", + active = false, + location = "localhost" + )) + listenerWrapper.onReceiverStopped(receiverStopped) + assertReceiverInfo(listener.receiverStopped.receiverInfo, receiverStopped.receiverInfo) + + val receiverError = StreamingListenerReceiverError(ReceiverInfo( + streamId = 2, + name = "test", + active = false, + location = "localhost", + lastErrorMessage = "failed", + lastError = "failed", + lastErrorTime = System.currentTimeMillis() + )) + listenerWrapper.onReceiverError(receiverError) + assertReceiverInfo(listener.receiverError.receiverInfo, receiverError.receiverInfo) + + val batchSubmitted = StreamingListenerBatchSubmitted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + None, + None, + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = None, + endTime = None, + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = None, + endTime = None, + failureReason = None)) + )) + listenerWrapper.onBatchSubmitted(batchSubmitted) + assertBatchInfo(listener.batchSubmitted.batchInfo, batchSubmitted.batchInfo) + + val batchStarted = StreamingListenerBatchStarted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + Some(1002L), + None, + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = None, + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = Some(1005L), + endTime = None, + failureReason = None)) + )) + listenerWrapper.onBatchStarted(batchStarted) + assertBatchInfo(listener.batchStarted.batchInfo, batchStarted.batchInfo) + + val batchCompleted = StreamingListenerBatchCompleted(BatchInfo( + batchTime = Time(1000L), + streamIdToInputInfo = Map( + 0 -> StreamInputInfo( + inputStreamId = 0, + numRecords = 1000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver1")), + 1 -> StreamInputInfo( + inputStreamId = 1, + numRecords = 2000, + metadata = Map(StreamInputInfo.METADATA_KEY_DESCRIPTION -> "receiver2"))), + submissionTime = 1001L, + Some(1002L), + Some(1010L), + outputOperationInfos = Map( + 0 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = Some(1004L), + failureReason = None), + 1 -> OutputOperationInfo( + batchTime = Time(1000L), + id = 1, + name = "op2", + description = "operation2", + startTime = Some(1005L), + endTime = Some(1010L), + failureReason = None)) + )) + listenerWrapper.onBatchCompleted(batchCompleted) + assertBatchInfo(listener.batchCompleted.batchInfo, batchCompleted.batchInfo) + + val outputOperationStarted = StreamingListenerOutputOperationStarted(OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = None, + failureReason = None + )) + listenerWrapper.onOutputOperationStarted(outputOperationStarted) + assertOutputOperationInfo(listener.outputOperationStarted.outputOperationInfo, + outputOperationStarted.outputOperationInfo) + + val outputOperationCompleted = StreamingListenerOutputOperationCompleted(OutputOperationInfo( + batchTime = Time(1000L), + id = 0, + name = "op1", + description = "operation1", + startTime = Some(1003L), + endTime = Some(1004L), + failureReason = None + )) + listenerWrapper.onOutputOperationCompleted(outputOperationCompleted) + assertOutputOperationInfo(listener.outputOperationCompleted.outputOperationInfo, + outputOperationCompleted.outputOperationInfo) + } + + private def assertReceiverInfo( + javaReceiverInfo: JavaReceiverInfo, receiverInfo: ReceiverInfo): Unit = { + assert(javaReceiverInfo.streamId === receiverInfo.streamId) + assert(javaReceiverInfo.name === receiverInfo.name) + assert(javaReceiverInfo.active === receiverInfo.active) + assert(javaReceiverInfo.location === receiverInfo.location) + assert(javaReceiverInfo.lastErrorMessage === receiverInfo.lastErrorMessage) + assert(javaReceiverInfo.lastError === receiverInfo.lastError) + assert(javaReceiverInfo.lastErrorTime === receiverInfo.lastErrorTime) + } + + private def assertBatchInfo(javaBatchInfo: JavaBatchInfo, batchInfo: BatchInfo): Unit = { + assert(javaBatchInfo.batchTime === batchInfo.batchTime) + assert(javaBatchInfo.streamIdToInputInfo.size === batchInfo.streamIdToInputInfo.size) + batchInfo.streamIdToInputInfo.foreach { case (streamId, streamInputInfo) => + assertStreamingInfo(javaBatchInfo.streamIdToInputInfo.get(streamId), streamInputInfo) + } + assert(javaBatchInfo.submissionTime === batchInfo.submissionTime) + assert(javaBatchInfo.processingStartTime === batchInfo.processingStartTime.getOrElse(-1)) + assert(javaBatchInfo.processingEndTime === batchInfo.processingEndTime.getOrElse(-1)) + assert(javaBatchInfo.schedulingDelay === batchInfo.schedulingDelay.getOrElse(-1)) + assert(javaBatchInfo.processingDelay === batchInfo.processingDelay.getOrElse(-1)) + assert(javaBatchInfo.totalDelay === batchInfo.totalDelay.getOrElse(-1)) + assert(javaBatchInfo.numRecords === batchInfo.numRecords) + assert(javaBatchInfo.outputOperationInfos.size === batchInfo.outputOperationInfos.size) + batchInfo.outputOperationInfos.foreach { case (outputOperationId, outputOperationInfo) => + assertOutputOperationInfo( + javaBatchInfo.outputOperationInfos.get(outputOperationId), outputOperationInfo) + } + } + + private def assertStreamingInfo( + javaStreamInputInfo: JavaStreamInputInfo, streamInputInfo: StreamInputInfo): Unit = { + assert(javaStreamInputInfo.inputStreamId === streamInputInfo.inputStreamId) + assert(javaStreamInputInfo.numRecords === streamInputInfo.numRecords) + assert(javaStreamInputInfo.metadata === streamInputInfo.metadata.asJava) + assert(javaStreamInputInfo.metadataDescription === streamInputInfo.metadataDescription.orNull) + } + + private def assertOutputOperationInfo( + javaOutputOperationInfo: JavaOutputOperationInfo, + outputOperationInfo: OutputOperationInfo): Unit = { + assert(javaOutputOperationInfo.batchTime === outputOperationInfo.batchTime) + assert(javaOutputOperationInfo.id === outputOperationInfo.id) + assert(javaOutputOperationInfo.name === outputOperationInfo.name) + assert(javaOutputOperationInfo.description === outputOperationInfo.description) + assert(javaOutputOperationInfo.startTime === outputOperationInfo.startTime.getOrElse(-1)) + assert(javaOutputOperationInfo.endTime === outputOperationInfo.endTime.getOrElse(-1)) + assert(javaOutputOperationInfo.failureReason === outputOperationInfo.failureReason.orNull) + } +} + +class TestJavaStreamingListener extends JavaStreamingListener { + + var receiverStarted: JavaStreamingListenerReceiverStarted = null + var receiverError: JavaStreamingListenerReceiverError = null + var receiverStopped: JavaStreamingListenerReceiverStopped = null + var batchSubmitted: JavaStreamingListenerBatchSubmitted = null + var batchStarted: JavaStreamingListenerBatchStarted = null + var batchCompleted: JavaStreamingListenerBatchCompleted = null + var outputOperationStarted: JavaStreamingListenerOutputOperationStarted = null + var outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted = null + + override def onReceiverStarted(receiverStarted: JavaStreamingListenerReceiverStarted): Unit = { + this.receiverStarted = receiverStarted + } + + override def onReceiverError(receiverError: JavaStreamingListenerReceiverError): Unit = { + this.receiverError = receiverError + } + + override def onReceiverStopped(receiverStopped: JavaStreamingListenerReceiverStopped): Unit = { + this.receiverStopped = receiverStopped + } + + override def onBatchSubmitted(batchSubmitted: JavaStreamingListenerBatchSubmitted): Unit = { + this.batchSubmitted = batchSubmitted + } + + override def onBatchStarted(batchStarted: JavaStreamingListenerBatchStarted): Unit = { + this.batchStarted = batchStarted + } + + override def onBatchCompleted(batchCompleted: JavaStreamingListenerBatchCompleted): Unit = { + this.batchCompleted = batchCompleted + } + + override def onOutputOperationStarted( + outputOperationStarted: JavaStreamingListenerOutputOperationStarted): Unit = { + this.outputOperationStarted = outputOperationStarted + } + + override def onOutputOperationCompleted( + outputOperationCompleted: JavaStreamingListenerOutputOperationCompleted): Unit = { + this.outputOperationCompleted = outputOperationCompleted + } +} From 6502944f39893b9dfb472f8406d5f3a02a316eff Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 9 Nov 2015 18:13:37 -0800 Subject: [PATCH 68/88] [SPARK-11333][STREAMING] Add executorId to ReceiverInfo and display it in UI Expose executorId to `ReceiverInfo` and UI since it's helpful when there are multiple executors running in the same host. Screenshot: screen shot 2015-11-02 at 10 52 19 am Author: Shixiong Zhu Author: zsxwing Closes #9418 from zsxwing/SPARK-11333. --- .../spark/streaming/api/java/JavaStreamingListener.scala | 1 + .../streaming/api/java/JavaStreamingListenerWrapper.scala | 1 + .../apache/spark/streaming/scheduler/ReceiverInfo.scala | 1 + .../spark/streaming/scheduler/ReceiverTrackingInfo.scala | 1 + .../org/apache/spark/streaming/ui/StreamingPage.scala | 8 ++++++-- .../spark/streaming/JavaStreamingListenerAPISuite.java | 3 +++ .../api/java/JavaStreamingListenerWrapperSuite.scala | 8 ++++++-- .../streaming/ui/StreamingJobProgressListenerSuite.scala | 6 +++--- 8 files changed, 22 insertions(+), 7 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala index c86c7101ff6d5..34429074fe804 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListener.scala @@ -140,6 +140,7 @@ private[streaming] case class JavaReceiverInfo( name: String, active: Boolean, location: String, + executorId: String, lastErrorMessage: String, lastError: String, lastErrorTime: Long) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala index 2c60b396a6616..b109b9f1cbeae 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapper.scala @@ -33,6 +33,7 @@ private[streaming] class JavaStreamingListenerWrapper(javaStreamingListener: Jav receiverInfo.name, receiverInfo.active, receiverInfo.location, + receiverInfo.executorId, receiverInfo.lastErrorMessage, receiverInfo.lastError, receiverInfo.lastErrorTime diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala index 59df892397fe0..3b35964114c02 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverInfo.scala @@ -30,6 +30,7 @@ case class ReceiverInfo( name: String, active: Boolean, location: String, + executorId: String, lastErrorMessage: String = "", lastError: String = "", lastErrorTime: Long = -1L diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala index ab0a84f05214d..4dc5bb9c3bfbe 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTrackingInfo.scala @@ -49,6 +49,7 @@ private[streaming] case class ReceiverTrackingInfo( name.getOrElse(""), state == ReceiverState.ACTIVE, location = runningExecutor.map(_.host).getOrElse(""), + executorId = runningExecutor.map(_.executorId).getOrElse(""), lastErrorMessage = errorInfo.map(_.lastErrorMessage).getOrElse(""), lastError = errorInfo.map(_.lastError).getOrElse(""), lastErrorTime = errorInfo.map(_.lastErrorTime).getOrElse(-1L) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index 96d943e75d272..4588b2163cd44 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -402,7 +402,7 @@ private[ui] class StreamingPage(parent: StreamingTab)
Status
-
Location
+
Executor ID / Host
Last Error Time
Last Error Message @@ -430,7 +430,11 @@ private[ui] class StreamingPage(parent: StreamingTab) val receiverActive = receiverInfo.map { info => if (info.active) "ACTIVE" else "INACTIVE" }.getOrElse(emptyCell) - val receiverLocation = receiverInfo.map(_.location).getOrElse(emptyCell) + val receiverLocation = receiverInfo.map { info => + val executorId = if (info.executorId.isEmpty) emptyCell else info.executorId + val location = if (info.location.isEmpty) emptyCell else info.location + s"$executorId / $location" + }.getOrElse(emptyCell) val receiverLastError = receiverInfo.map { info => val msg = s"${info.lastErrorMessage} - ${info.lastError}" if (msg.size > 100) msg.take(97) + "..." else msg diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java index 8cc285aa7fb34..67b2a0703e02b 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaStreamingListenerAPISuite.java @@ -29,6 +29,7 @@ public void onReceiverStarted(JavaStreamingListenerReceiverStarted receiverStart receiverInfo.name(); receiverInfo.active(); receiverInfo.location(); + receiverInfo.executorId(); receiverInfo.lastErrorMessage(); receiverInfo.lastError(); receiverInfo.lastErrorTime(); @@ -41,6 +42,7 @@ public void onReceiverError(JavaStreamingListenerReceiverError receiverError) { receiverInfo.name(); receiverInfo.active(); receiverInfo.location(); + receiverInfo.executorId(); receiverInfo.lastErrorMessage(); receiverInfo.lastError(); receiverInfo.lastErrorTime(); @@ -53,6 +55,7 @@ public void onReceiverStopped(JavaStreamingListenerReceiverStopped receiverStopp receiverInfo.name(); receiverInfo.active(); receiverInfo.location(); + receiverInfo.executorId(); receiverInfo.lastErrorMessage(); receiverInfo.lastError(); receiverInfo.lastErrorTime(); diff --git a/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala index 6d6d61e70cafc..0295e059f7bc2 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala +++ b/streaming/src/test/java/org/apache/spark/streaming/api/java/JavaStreamingListenerWrapperSuite.scala @@ -33,7 +33,8 @@ class JavaStreamingListenerWrapperSuite extends SparkFunSuite { streamId = 2, name = "test", active = true, - location = "localhost" + location = "localhost", + executorId = "1" )) listenerWrapper.onReceiverStarted(receiverStarted) assertReceiverInfo(listener.receiverStarted.receiverInfo, receiverStarted.receiverInfo) @@ -42,7 +43,8 @@ class JavaStreamingListenerWrapperSuite extends SparkFunSuite { streamId = 2, name = "test", active = false, - location = "localhost" + location = "localhost", + executorId = "1" )) listenerWrapper.onReceiverStopped(receiverStopped) assertReceiverInfo(listener.receiverStopped.receiverInfo, receiverStopped.receiverInfo) @@ -52,6 +54,7 @@ class JavaStreamingListenerWrapperSuite extends SparkFunSuite { name = "test", active = false, location = "localhost", + executorId = "1", lastErrorMessage = "failed", lastError = "failed", lastErrorTime = System.currentTimeMillis() @@ -197,6 +200,7 @@ class JavaStreamingListenerWrapperSuite extends SparkFunSuite { assert(javaReceiverInfo.name === receiverInfo.name) assert(javaReceiverInfo.active === receiverInfo.active) assert(javaReceiverInfo.location === receiverInfo.location) + assert(javaReceiverInfo.executorId === receiverInfo.executorId) assert(javaReceiverInfo.lastErrorMessage === receiverInfo.lastErrorMessage) assert(javaReceiverInfo.lastError === receiverInfo.lastError) assert(javaReceiverInfo.lastErrorTime === receiverInfo.lastErrorTime) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala index af4718b4eb705..34cd7435569e1 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ui/StreamingJobProgressListenerSuite.scala @@ -130,20 +130,20 @@ class StreamingJobProgressListenerSuite extends TestSuiteBase with Matchers { listener.numTotalReceivedRecords should be (600) // onReceiverStarted - val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost") + val receiverInfoStarted = ReceiverInfo(0, "test", true, "localhost", "0") listener.onReceiverStarted(StreamingListenerReceiverStarted(receiverInfoStarted)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (None) // onReceiverError - val receiverInfoError = ReceiverInfo(1, "test", true, "localhost") + val receiverInfoError = ReceiverInfo(1, "test", true, "localhost", "1") listener.onReceiverError(StreamingListenerReceiverError(receiverInfoError)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) listener.receiverInfo(2) should be (None) // onReceiverStopped - val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost") + val receiverInfoStopped = ReceiverInfo(2, "test", true, "localhost", "2") listener.onReceiverStopped(StreamingListenerReceiverStopped(receiverInfoStopped)) listener.receiverInfo(0) should be (Some(receiverInfoStarted)) listener.receiverInfo(1) should be (Some(receiverInfoError)) From 1431319e5bc46c7225a8edeeec482816d14a83b8 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 9 Nov 2015 18:53:57 -0800 Subject: [PATCH 69/88] Add mockito as an explicit test dependency to spark-streaming While sbt successfully compiles as it properly pulls the mockito dependency, maven builds have broken. We need this in ASAP. tdas Author: Burak Yavuz Closes #9584 from brkyvz/fix-master. --- streaming/pom.xml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/streaming/pom.xml b/streaming/pom.xml index 145c8a7321c05..435e16db13ab4 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -93,6 +93,11 @@ selenium-java test
+ + org.mockito + mockito-core + test + target/scala-${scala.binary.version}/classes From c4e19b3819df4cd7a1c495a00bd2844cf55f4dbd Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 9 Nov 2015 21:06:01 -0800 Subject: [PATCH 70/88] [SPARK-11587][SPARKR] Fix the summary generic to match base R The signature is summary(object, ...) as defined in https://stat.ethz.ch/R-manual/R-devel/library/base/html/summary.html Author: Shivaram Venkataraman Closes #9582 from shivaram/summary-fix. --- R/pkg/R/DataFrame.R | 6 +++--- R/pkg/R/generics.R | 2 +- R/pkg/R/mllib.R | 12 ++++++------ R/pkg/inst/tests/test_mllib.R | 6 ++++++ 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 44ce9414da5cf..e9013aa34a84f 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1944,9 +1944,9 @@ setMethod("describe", #' @rdname summary #' @name summary setMethod("summary", - signature(x = "DataFrame"), - function(x) { - describe(x) + signature(object = "DataFrame"), + function(object, ...) { + describe(object) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 083d37fee28a4..efef7d66b522c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -561,7 +561,7 @@ setGeneric("summarize", function(x,...) { standardGeneric("summarize") }) #' @rdname summary #' @export -setGeneric("summary", function(x, ...) { standardGeneric("summary") }) +setGeneric("summary", function(object, ...) { standardGeneric("summary") }) # @rdname tojson # @export diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 7ff859741b4a0..7126b7cde4bd7 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -89,17 +89,17 @@ setMethod("predict", signature(object = "PipelineModel"), #' model <- glm(y ~ x, trainingData) #' summary(model) #'} -setMethod("summary", signature(x = "PipelineModel"), - function(x, ...) { +setMethod("summary", signature(object = "PipelineModel"), + function(object, ...) { modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelName", x@model) + "getModelName", object@model) features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelFeatures", x@model) + "getModelFeatures", object@model) coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelCoefficients", x@model) + "getModelCoefficients", object@model) if (modelName == "LinearRegressionModel") { devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelDevianceResiduals", x@model) + "getModelDevianceResiduals", object@model) devianceResiduals <- matrix(devianceResiduals, nrow = 1) colnames(devianceResiduals) <- c("Min", "Max") rownames(devianceResiduals) <- rep("", times = 1) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 2606407bdcb44..42287ea19adc5 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -113,3 +113,9 @@ test_that("summary coefficients match with native glm of family 'binomial'", { rownames(stats$Coefficients) == c("(Intercept)", "Sepal_Length", "Sepal_Width"))) }) + +test_that("summary works on base GLM models", { + baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) + baseSummary <- summary(baseModel) + expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) +}) From d6cd3a18e720e8f6f1f307e0dffad3512952d997 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 9 Nov 2015 23:27:36 -0800 Subject: [PATCH 71/88] [SPARK-11599] [SQL] fix NPE when resolve Hive UDF in SQLParser The DataFrame APIs that takes a SQL expression always use SQLParser, then the HiveFunctionRegistry will called outside of Hive state, cause NPE if there is not a active Session State for current thread (in PySpark). cc rxin yhuai Author: Davies Liu Closes #9576 from davies/hive_udf. --- .../apache/spark/sql/hive/HiveContext.scala | 10 +++++- .../sql/hive/execution/HiveQuerySuite.scala | 33 ++++++++++++++----- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2d72b959af134..c5f69657f5293 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -454,7 +454,15 @@ class HiveContext private[hive]( // Note that HiveUDFs will be overridden by functions registered in this context. @transient override protected[sql] lazy val functionRegistry: FunctionRegistry = - new HiveFunctionRegistry(FunctionRegistry.builtin.copy()) + new HiveFunctionRegistry(FunctionRegistry.builtin.copy()) { + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + // Hive Registry need current database to lookup function + // TODO: the current database of executionHive should be consistent with metadataHive + executionHive.withHiveState { + super.lookupFunction(name, children) + } + } + } // The Hive UDF current_database() is foldable, will be evaluated by optimizer, but the optimizer // can't access the SessionState of metadataHive. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 78378c8b69c7a..f0a7a6cc7a1e3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -20,22 +20,19 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} -import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin - import scala.util.Try -import org.scalatest.BeforeAndAfter - import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkFiles, SparkException} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.test.TestHiveContext -import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} +import org.apache.spark.{SparkException, SparkFiles} case class TestData(a: Int, b: String) @@ -1237,6 +1234,26 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } + test("lookup hive UDF in another thread") { + val e = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + } + assert(e.getMessage.contains("undefined function not_a_udf")) + var success = false + val t = new Thread("test") { + override def run(): Unit = { + val e = intercept[AnalysisException] { + range(1).selectExpr("not_a_udf()") + } + assert(e.getMessage.contains("undefined function not_a_udf")) + success = true + } + } + t.start() + t.join() + assert(success) + } + createQueryTest("select from thrift based table", "SELECT * from src_thrift") From 521b3cae118d1e22c170e2aad43f9baa162db55e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 9 Nov 2015 23:28:32 -0800 Subject: [PATCH 72/88] [SPARK-11598] [SQL] enable tests for ShuffledHashOuterJoin Author: Davies Liu Closes #9573 from davies/join_condition. --- .../org/apache/spark/sql/JoinSuite.scala | 435 ++++++++++-------- 1 file changed, 231 insertions(+), 204 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index a9ca46cab067d..3f3b837f7581c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -237,214 +237,241 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(2, 2, 2, 2) :: Nil) } - test("left outer join") { - checkAnswer( - upperCaseData.join(lowerCaseData, $"n" === $"N", "left"), - Row(1, "A", 1, "a") :: - Row(2, "B", 2, "b") :: - Row(3, "C", 3, "c") :: - Row(4, "D", 4, "d") :: - Row(5, "E", null, null) :: - Row(6, "F", null, null) :: Nil) - - checkAnswer( - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left"), - Row(1, "A", null, null) :: - Row(2, "B", 2, "b") :: - Row(3, "C", 3, "c") :: - Row(4, "D", 4, "d") :: - Row(5, "E", null, null) :: - Row(6, "F", null, null) :: Nil) - - checkAnswer( - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left"), - Row(1, "A", null, null) :: - Row(2, "B", 2, "b") :: - Row(3, "C", 3, "c") :: - Row(4, "D", 4, "d") :: - Row(5, "E", null, null) :: - Row(6, "F", null, null) :: Nil) - - checkAnswer( - upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left"), - Row(1, "A", 1, "a") :: - Row(2, "B", 2, "b") :: - Row(3, "C", 3, "c") :: - Row(4, "D", 4, "d") :: - Row(5, "E", null, null) :: - Row(6, "F", null, null) :: Nil) - - // Make sure we are choosing left.outputPartitioning as the - // outputPartitioning for the outer join operator. - checkAnswer( - sql( - """ - |SELECT l.N, count(*) - |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY l.N - """.stripMargin), - Row(1, 1) :: - Row(2, 1) :: - Row(3, 1) :: - Row(4, 1) :: - Row(5, 1) :: - Row(6, 1) :: Nil) - - checkAnswer( - sql( - """ - |SELECT r.a, count(*) - |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY r.a - """.stripMargin), - Row(null, 6) :: Nil) - } + def test_outer_join(useSMJ: Boolean): Unit = { + + val algo = if (useSMJ) "SortMergeOuterJoin" else "ShuffledHashOuterJoin" + + test("left outer join: " + algo) { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> useSMJ.toString) { + + checkAnswer( + upperCaseData.join(lowerCaseData, $"n" === $"N", "left"), + Row(1, "A", 1, "a") :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left"), + Row(1, "A", null, null) :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left"), + Row(1, "A", null, null) :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) + + checkAnswer( + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left"), + Row(1, "A", 1, "a") :: + Row(2, "B", 2, "b") :: + Row(3, "C", 3, "c") :: + Row(4, "D", 4, "d") :: + Row(5, "E", null, null) :: + Row(6, "F", null, null) :: Nil) + + // Make sure we are choosing left.outputPartitioning as the + // outputPartitioning for the outer join operator. + checkAnswer( + sql( + """ + |SELECT l.N, count(*) + |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY l.N + """. + stripMargin), + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.a, count(*) + |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY r.a + """.stripMargin), + Row(null, 6) :: Nil) + } + } - test("right outer join") { - checkAnswer( - lowerCaseData.join(upperCaseData, $"n" === $"N", "right"), - Row(1, "a", 1, "A") :: - Row(2, "b", 2, "B") :: - Row(3, "c", 3, "C") :: - Row(4, "d", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - checkAnswer( - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right"), - Row(null, null, 1, "A") :: - Row(2, "b", 2, "B") :: - Row(3, "c", 3, "C") :: - Row(4, "d", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - checkAnswer( - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right"), - Row(null, null, 1, "A") :: - Row(2, "b", 2, "B") :: - Row(3, "c", 3, "C") :: - Row(4, "d", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - checkAnswer( - lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right"), - Row(1, "a", 1, "A") :: - Row(2, "b", 2, "B") :: - Row(3, "c", 3, "C") :: - Row(4, "d", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - - // Make sure we are choosing right.outputPartitioning as the - // outputPartitioning for the outer join operator. - checkAnswer( - sql( - """ - |SELECT l.a, count(*) - |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) - |GROUP BY l.a - """.stripMargin), - Row(null, 6)) + test("right outer join: " + algo) { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> useSMJ.toString) { + checkAnswer( + lowerCaseData.join(upperCaseData, $"n" === $"N", "right"), + Row(1, "a", 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right"), + Row(null, null, 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right"), + Row(null, null, 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + checkAnswer( + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right"), + Row(1, "a", 1, "A") :: + Row(2, "b", 2, "B") :: + Row(3, "c", 3, "C") :: + Row(4, "d", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + + // Make sure we are choosing right.outputPartitioning as the + // outputPartitioning for the outer join operator. + checkAnswer( + sql( + """ + |SELECT l.a, count(*) + |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY l.a + """.stripMargin), + Row(null, + 6)) + + checkAnswer( + sql( + """ + |SELECT r.N, count(*) + |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY r.N + """.stripMargin), + Row(1 + , 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) + } + } - checkAnswer( - sql( - """ - |SELECT r.N, count(*) - |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) - |GROUP BY r.N - """.stripMargin), - Row(1, 1) :: - Row(2, 1) :: - Row(3, 1) :: - Row(4, 1) :: - Row(5, 1) :: - Row(6, 1) :: Nil) + test("full outer join: " + algo) { + withSQLConf(SQLConf.SORTMERGE_JOIN.key -> useSMJ.toString) { + + upperCaseData.where('N <= 4).registerTempTable("left") + upperCaseData.where('N >= 3).registerTempTable("right") + + val left = UnresolvedRelation(TableIdentifier("left"), None) + val right = UnresolvedRelation(TableIdentifier("right"), None) + + checkAnswer( + left.join(right, $"left.N" === $"right.N", "full"), + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== 3), "full"), + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", null, null) :: + Row(null, null, 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + + checkAnswer( + left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== 3), "full"), + Row(1, "A", null, null) :: + Row(2, "B", null, null) :: + Row(3, "C", null, null) :: + Row(null, null, 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) + + // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join + // operator. + checkAnswer( + sql( + """ + |SELECT l.a, count(*) + |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY l.a + """. + stripMargin), + Row( + null, 10)) + + checkAnswer( + sql( + """ + |SELECT r.N, count(*) + |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) + |GROUP BY r.N + """.stripMargin), + Row + (1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: + Row(null, 4) :: Nil) + + checkAnswer( + sql( + """ + |SELECT l.N, count(*) + |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY l.N + """.stripMargin), + Row(1 + , 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: + Row(null, 4) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.a, count(*) + |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) + |GROUP BY r.a + """. + stripMargin), + Row(null, 10)) + } + } } - test("full outer join") { - upperCaseData.where('N <= 4).registerTempTable("left") - upperCaseData.where('N >= 3).registerTempTable("right") - - val left = UnresolvedRelation(TableIdentifier("left"), None) - val right = UnresolvedRelation(TableIdentifier("right"), None) - - checkAnswer( - left.join(right, $"left.N" === $"right.N", "full"), - Row(1, "A", null, null) :: - Row(2, "B", null, null) :: - Row(3, "C", 3, "C") :: - Row(4, "D", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - - checkAnswer( - left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== 3), "full"), - Row(1, "A", null, null) :: - Row(2, "B", null, null) :: - Row(3, "C", null, null) :: - Row(null, null, 3, "C") :: - Row(4, "D", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - - checkAnswer( - left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== 3), "full"), - Row(1, "A", null, null) :: - Row(2, "B", null, null) :: - Row(3, "C", null, null) :: - Row(null, null, 3, "C") :: - Row(4, "D", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) - - // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. - checkAnswer( - sql( - """ - |SELECT l.a, count(*) - |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) - |GROUP BY l.a - """.stripMargin), - Row(null, 10)) - - checkAnswer( - sql( - """ - |SELECT r.N, count(*) - |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) - |GROUP BY r.N - """.stripMargin), - Row(1, 1) :: - Row(2, 1) :: - Row(3, 1) :: - Row(4, 1) :: - Row(5, 1) :: - Row(6, 1) :: - Row(null, 4) :: Nil) - - checkAnswer( - sql( - """ - |SELECT l.N, count(*) - |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY l.N - """.stripMargin), - Row(1, 1) :: - Row(2, 1) :: - Row(3, 1) :: - Row(4, 1) :: - Row(5, 1) :: - Row(6, 1) :: - Row(null, 4) :: Nil) - - checkAnswer( - sql( - """ - |SELECT r.a, count(*) - |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) - |GROUP BY r.a - """.stripMargin), - Row(null, 10)) - } + // test SortMergeOuterJoin + test_outer_join(true) + // test ShuffledHashOuterJoin + test_outer_join(false) test("broadcasted left semi join operator selection") { sqlContext.cacheManager.clearCache() From 5507a9d0935aa42d65c3a4fa65da680b5af14faf Mon Sep 17 00:00:00 2001 From: Paul Chandler Date: Tue, 10 Nov 2015 12:59:53 +0100 Subject: [PATCH 73/88] Fix typo in driver page "Comamnd property" => "Command property" Author: Paul Chandler Closes #9578 from pestilence669/fix_spelling. --- .../scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index e8ef60bd5428a..bc67fd460d9a9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -46,7 +46,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") val schedulerHeaders = Seq("Scheduler property", "Value") val commandEnvHeaders = Seq("Command environment variable", "Value") val launchedHeaders = Seq("Launched property", "Value") - val commandHeaders = Seq("Comamnd property", "Value") + val commandHeaders = Seq("Command property", "Value") val retryHeaders = Seq("Last failed status", "Next retry time", "Retry count") val driverDescription = Iterable.apply(driverState.description) val submissionState = Iterable.apply(driverState.submissionState) From a81f47ff7498e7063c855ccf75bba81ab101b43e Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 10 Nov 2015 10:05:53 -0800 Subject: [PATCH 74/88] [SPARK-11382] Replace example code in mllib-decision-tree.md using include_example https://issues.apache.org/jira/browse/SPARK-11382 B.T.W. I fix an error in naive_bayes_example.py. Author: Xusen Yin Closes #9596 from yinxusen/SPARK-11382. --- docs/mllib-decision-tree.md | 253 +----------------- ...JavaDecisionTreeClassificationExample.java | 91 +++++++ .../JavaDecisionTreeRegressionExample.java | 96 +++++++ .../decision_tree_classification_example.py | 55 ++++ .../mllib/decision_tree_regression_example.py | 56 ++++ .../main/python/mllib/naive_bayes_example.py | 1 + .../DecisionTreeClassificationExample.scala | 67 +++++ .../mllib/DecisionTreeRegressionExample.scala | 66 +++++ 8 files changed, 438 insertions(+), 247 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java create mode 100644 examples/src/main/python/mllib/decision_tree_classification_example.py create mode 100644 examples/src/main/python/mllib/decision_tree_regression_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index b5b454bc69245..77ce34e91af3c 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -194,137 +194,19 @@ maximum tree depth of 5. The test error is calculated to measure the algorithm a
Refer to the [`DecisionTree` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.DecisionTreeModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a DecisionTree model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val numClasses = 2 -val categoricalFeaturesInfo = Map[Int, Int]() -val impurity = "gini" -val maxDepth = 5 -val maxBins = 32 - -val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, - impurity, maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelAndPreds = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() -println("Test Error = " + testErr) -println("Learned classification tree model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala %}
Refer to the [`DecisionTree` Java docs](api/java/org/apache/spark/mllib/tree/DecisionTree.html) and [`DecisionTreeModel` Java docs](api/java/org/apache/spark/mllib/tree/model/DecisionTreeModel.html) for details on the API. -{% highlight java %} -import java.util.HashMap; -import scala.Tuple2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Set parameters. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Integer numClasses = 2; -Map categoricalFeaturesInfo = new HashMap(); -String impurity = "gini"; -Integer maxDepth = 5; -Integer maxBins = 32; - -// Train a DecisionTree model for classification. -final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testErr = - 1.0 * predictionAndLabel.filter(new Function, Boolean>() { - @Override - public Boolean call(Tuple2 pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); -System.out.println("Test Error: " + testErr); -System.out.println("Learned classification tree model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java %}
Refer to the [`DecisionTree` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTreeModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree, DecisionTreeModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, - impurity='gini', maxDepth=5, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) -print('Test Error = ' + str(testErr)) -print('Learned classification tree model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/decision_tree_classification_example.py %}
@@ -343,142 +225,19 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
Refer to the [`DecisionTree` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Scala docs](api/scala/index.html#org.apache.spark.mllib.tree.model.DecisionTreeModel) for details on the API. -{% highlight scala %} -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.model.DecisionTreeModel -import org.apache.spark.mllib.util.MLUtils - -// Load and parse the data file. -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// Split the data into training and test sets (30% held out for testing) -val splits = data.randomSplit(Array(0.7, 0.3)) -val (trainingData, testData) = (splits(0), splits(1)) - -// Train a DecisionTree model. -// Empty categoricalFeaturesInfo indicates all features are continuous. -val categoricalFeaturesInfo = Map[Int, Int]() -val impurity = "variance" -val maxDepth = 5 -val maxBins = 32 - -val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, - maxDepth, maxBins) - -// Evaluate model on test instances and compute test error -val labelsAndPredictions = testData.map { point => - val prediction = model.predict(point.features) - (point.label, prediction) -} -val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() -println("Test Mean Squared Error = " + testMSE) -println("Learned regression tree model:\n" + model.toDebugString) - -// Save and load model -model.save(sc, "myModelPath") -val sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala %}
Refer to the [`DecisionTree` Java docs](api/java/org/apache/spark/mllib/tree/DecisionTree.html) and [`DecisionTreeModel` Java docs](api/java/org/apache/spark/mllib/tree/model/DecisionTreeModel.html) for details on the API. -{% highlight java %} -import java.util.HashMap; -import scala.Tuple2; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.tree.DecisionTree; -import org.apache.spark.mllib.tree.model.DecisionTreeModel; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; - -SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree"); -JavaSparkContext sc = new JavaSparkContext(sparkConf); - -// Load and parse the data file. -String datapath = "data/mllib/sample_libsvm_data.txt"; -JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD(); -// Split the data into training and test sets (30% held out for testing) -JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); -JavaRDD trainingData = splits[0]; -JavaRDD testData = splits[1]; - -// Set parameters. -// Empty categoricalFeaturesInfo indicates all features are continuous. -Map categoricalFeaturesInfo = new HashMap(); -String impurity = "variance"; -Integer maxDepth = 5; -Integer maxBins = 32; - -// Train a DecisionTree model. -final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, - categoricalFeaturesInfo, impurity, maxDepth, maxBins); - -// Evaluate model on test instances and compute test error -JavaPairRDD predictionAndLabel = - testData.mapToPair(new PairFunction() { - @Override - public Tuple2 call(LabeledPoint p) { - return new Tuple2(model.predict(p.features()), p.label()); - } - }); -Double testMSE = - predictionAndLabel.map(new Function, Double>() { - @Override - public Double call(Tuple2 pl) { - Double diff = pl._1() - pl._2(); - return diff * diff; - } - }).reduce(new Function2() { - @Override - public Double call(Double a, Double b) { - return a + b; - } - }) / testData.count(); -System.out.println("Test Mean Squared Error: " + testMSE); -System.out.println("Learned regression tree model:\n" + model.toDebugString()); - -// Save and load model -model.save(sc.sc(), "myModelPath"); -DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath"); -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java %}
Refer to the [`DecisionTree` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTree) and [`DecisionTreeModel` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.tree.DecisionTreeModel) for more details on the API. -{% highlight python %} -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.tree import DecisionTree, DecisionTreeModel -from pyspark.mllib.util import MLUtils - -# Load and parse the data file into an RDD of LabeledPoint. -data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') -# Split the data into training and test sets (30% held out for testing) -(trainingData, testData) = data.randomSplit([0.7, 0.3]) - -# Train a DecisionTree model. -# Empty categoricalFeaturesInfo indicates all features are continuous. -model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={}, - impurity='variance', maxDepth=5, maxBins=32) - -# Evaluate model on test instances and compute test error -predictions = model.predict(testData.map(lambda x: x.features)) -labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) -testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count()) -print('Test Mean Squared Error = ' + str(testMSE)) -print('Learned regression tree model:') -print(model.toDebugString()) - -# Save and load model -model.save(sc, "myModelPath") -sameModel = DecisionTreeModel.load(sc, "myModelPath") -{% endhighlight %} +{% include_example python/mllib/decision_tree_regression_example.py %}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java new file mode 100644 index 0000000000000..5839b0cf8a8f8 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +class JavaDecisionTreeClassificationExample { + + public static void main(String[] args) { + + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Integer numClasses = 2; + Map categoricalFeaturesInfo = new HashMap(); + String impurity = "gini"; + Integer maxDepth = 5; + Integer maxBins = 32; + + // Train a DecisionTree model for classification. + final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / testData.count(); + + System.out.println("Test Error: " + testErr); + System.out.println("Learned classification tree model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel"); + DecisionTreeModel sameModel = DecisionTreeModel + .load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel"); + // $example off$ + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java new file mode 100644 index 0000000000000..ccde578249f7c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeRegressionExample.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.HashMap; +import java.util.Map; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.DecisionTree; +import org.apache.spark.mllib.tree.model.DecisionTreeModel; +import org.apache.spark.mllib.util.MLUtils; +// $example off$ + +class JavaDecisionTreeRegressionExample { + + public static void main(String[] args) { + + // $example on$ + SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample"); + JavaSparkContext jsc = new JavaSparkContext(sparkConf); + + // Load and parse the data file. + String datapath = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD(); + // Split the data into training and test sets (30% held out for testing) + JavaRDD[] splits = data.randomSplit(new double[]{0.7, 0.3}); + JavaRDD trainingData = splits[0]; + JavaRDD testData = splits[1]; + + // Set parameters. + // Empty categoricalFeaturesInfo indicates all features are continuous. + Map categoricalFeaturesInfo = new HashMap(); + String impurity = "variance"; + Integer maxDepth = 5; + Integer maxBins = 32; + + // Train a DecisionTree model. + final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData, + categoricalFeaturesInfo, impurity, maxDepth, maxBins); + + // Evaluate model on test instances and compute test error + JavaPairRDD predictionAndLabel = + testData.mapToPair(new PairFunction() { + @Override + public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double testMSE = + predictionAndLabel.map(new Function, Double>() { + @Override + public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override + public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Test Mean Squared Error: " + testMSE); + System.out.println("Learned regression tree model:\n" + model.toDebugString()); + + // Save and load model + model.save(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel"); + DecisionTreeModel sameModel = DecisionTreeModel + .load(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel"); + // $example off$ + } +} diff --git a/examples/src/main/python/mllib/decision_tree_classification_example.py b/examples/src/main/python/mllib/decision_tree_classification_example.py new file mode 100644 index 0000000000000..1b529768b6c62 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_classification_example.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Classification Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonDecisionTreeClassificationExample") + + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + model = DecisionTree.trainClassifier(trainingData, numClasses=2, categoricalFeaturesInfo={}, + impurity='gini', maxDepth=5, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification tree model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myDecisionTreeClassificationModel") + sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel") + # $example off$ diff --git a/examples/src/main/python/mllib/decision_tree_regression_example.py b/examples/src/main/python/mllib/decision_tree_regression_example.py new file mode 100644 index 0000000000000..cf518eac67e81 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_regression_example.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision Tree Regression Example. +""" +from __future__ import print_function + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.tree import DecisionTree, DecisionTreeModel +from pyspark.mllib.util import MLUtils +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="PythonDecisionTreeRegressionExample") + + # $example on$ + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + # Train a DecisionTree model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo={}, + impurity='variance', maxDepth=5, maxBins=32) + + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() /\ + float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression tree model:') + print(model.toDebugString()) + + # Save and load model + model.save(sc, "target/tmp/myDecisionTreeRegressionModel") + sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel") + # $example off$ diff --git a/examples/src/main/python/mllib/naive_bayes_example.py b/examples/src/main/python/mllib/naive_bayes_example.py index a2e7dacf25491..f5e120c678fcf 100644 --- a/examples/src/main/python/mllib/naive_bayes_example.py +++ b/examples/src/main/python/mllib/naive_bayes_example.py @@ -20,6 +20,7 @@ """ from __future__ import print_function +from pyspark import SparkContext # $example on$ from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel from pyspark.mllib.linalg import Vectors diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala new file mode 100644 index 0000000000000..d427bbadaa0c1 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +object DecisionTreeClassificationExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeClassificationExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a DecisionTree model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val numClasses = 2 + val categoricalFeaturesInfo = Map[Int, Int]() + val impurity = "gini" + val maxDepth = 5 + val maxBins = 32 + + val model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, + impurity, maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelAndPreds = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count() + println("Test Error = " + testErr) + println("Learned classification tree model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myDecisionTreeClassificationModel") + val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeClassificationModel") + // $example off$ + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala new file mode 100644 index 0000000000000..fb05e7d9c5065 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.util.MLUtils +// $example off$ +import org.apache.spark.{SparkConf, SparkContext} + +object DecisionTreeRegressionExample { + + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("DecisionTreeRegressionExample") + val sc = new SparkContext(conf) + + // $example on$ + // Load and parse the data file. + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + // Split the data into training and test sets (30% held out for testing) + val splits = data.randomSplit(Array(0.7, 0.3)) + val (trainingData, testData) = (splits(0), splits(1)) + + // Train a DecisionTree model. + // Empty categoricalFeaturesInfo indicates all features are continuous. + val categoricalFeaturesInfo = Map[Int, Int]() + val impurity = "variance" + val maxDepth = 5 + val maxBins = 32 + + val model = DecisionTree.trainRegressor(trainingData, categoricalFeaturesInfo, impurity, + maxDepth, maxBins) + + // Evaluate model on test instances and compute test error + val labelsAndPredictions = testData.map { point => + val prediction = model.predict(point.features) + (point.label, prediction) + } + val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean() + println("Test Mean Squared Error = " + testMSE) + println("Learned regression tree model:\n" + model.toDebugString) + + // Save and load model + model.save(sc, "target/tmp/myDecisionTreeRegressionModel") + val sameModel = DecisionTreeModel.load(sc, "target/tmp/myDecisionTreeRegressionModel") + // $example off$ + } +} +// scalastyle:on println From 689386b1c60997e4505749915f7005a52c207de2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 10 Nov 2015 10:14:19 -0800 Subject: [PATCH 75/88] [SPARK-7841][BUILD] Stop using retrieveManaged to retrieve dependencies in SBT This patch modifies Spark's SBT build so that it no longer uses `retrieveManaged` / `lib_managed` to store its dependencies. The motivations for this change are nicely described on the JIRA ticket ([SPARK-7841](https://issues.apache.org/jira/browse/SPARK-7841)); my personal interest in doing this stems from the fact that `lib_managed` has caused me some pain while debugging dependency issues in another PR of mine. Removing our use of `lib_managed` would be trivial except for one snag: the Datanucleus JARs, required by Spark SQL's Hive integration, cannot be included in assembly JARs due to problems with merging OSGI `plugin.xml` files. As a result, several places in the packaging and deployment pipeline assume that these Datanucleus JARs are copied to `lib_managed/jars`. In the interest of maintaining compatibility, I have chosen to retain the `lib_managed/jars` directory _only_ for these Datanucleus JARs and have added custom code to `SparkBuild.scala` to automatically copy those JARs to that folder as part of the `assembly` task. `dev/mima` also depended on `lib_managed` in a hacky way in order to set classpaths when generating MiMa excludes; I've updated this to obtain the classpaths directly from SBT instead. /cc dragos marmbrus pwendell srowen Author: Josh Rosen Closes #9575 from JoshRosen/SPARK-7841. --- dev/mima | 2 +- project/SparkBuild.scala | 22 +++++++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/dev/mima b/dev/mima index 2952fa65d42ff..d5baffc6ef8a3 100755 --- a/dev/mima +++ b/dev/mima @@ -38,7 +38,7 @@ generate_mima_ignore() { # it did not process the new classes (which are in assembly jar). generate_mima_ignore -export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`" +export SPARK_CLASSPATH="$(build/sbt "export oldDeps/fullClasspath" | tail -n1)" echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" generate_mima_ignore diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b75ed13a78c68..a9fb741d75933 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -16,6 +16,7 @@ */ import java.io._ +import java.nio.file.Files import scala.util.Properties import scala.collection.JavaConverters._ @@ -135,8 +136,6 @@ object SparkBuild extends PomBuild { .orElse(sys.props.get("java.home").map { p => new File(p).getParentFile().getAbsolutePath() }) .map(file), incOptions := incOptions.value.withNameHashing(true), - retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, unidocGenjavadocVersion := "0.9-spark0", @@ -326,8 +325,6 @@ object OldDeps { def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq( name := "old-deps", scalaVersion := "2.10.5", - retrieveManaged := true, - retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx", @@ -404,6 +401,8 @@ object Assembly { val hadoopVersion = taskKey[String]("The version of hadoop that spark is compiled against.") + val deployDatanucleusJars = taskKey[Unit]("Deploy datanucleus jars to the spark/lib_managed/jars directory") + lazy val settings = assemblySettings ++ Seq( test in assembly := {}, hadoopVersion := { @@ -429,7 +428,20 @@ object Assembly { case m if m.toLowerCase.startsWith("meta-inf/services/") => MergeStrategy.filterDistinctLines case "reference.conf" => MergeStrategy.concat case _ => MergeStrategy.first - } + }, + deployDatanucleusJars := { + val jars: Seq[File] = (fullClasspath in assembly).value.map(_.data) + .filter(_.getPath.contains("org.datanucleus")) + var libManagedJars = new File(BuildCommons.sparkHome, "lib_managed/jars") + libManagedJars.mkdirs() + jars.foreach { jar => + val dest = new File(libManagedJars, jar.getName) + if (!dest.exists()) { + Files.copy(jar.toPath, dest.toPath) + } + } + }, + assembly <<= assembly.dependsOn(deployDatanucleusJars) ) } From 6e5fc37883ed81c3ee2338145a48de3036d19399 Mon Sep 17 00:00:00 2001 From: Lianhui Wang Date: Tue, 10 Nov 2015 10:40:08 -0800 Subject: [PATCH 76/88] [SPARK-11252][NETWORK] ShuffleClient should release connection after fetching blocks had been completed for external shuffle with yarn's external shuffle, ExternalShuffleClient of executors reserve its connections for yarn's NodeManager until application has been completed. so it will make NodeManager and executors have many socket connections. in order to reduce network pressure of NodeManager's shuffleService, after registerWithShuffleServer or fetchBlocks have been completed in ExternalShuffleClient, connection for NM's shuffleService needs to be closed.andrewor14 rxin vanzin Author: Lianhui Wang Closes #9227 from lianhuiwang/spark-11252. --- .../spark/deploy/ExternalShuffleService.scala | 3 +- .../spark/network/TransportContext.java | 11 +++++- .../client/TransportClientFactory.java | 10 ++++++ .../server/TransportChannelHandler.java | 26 +++++++++----- .../network/TransportClientFactorySuite.java | 34 +++++++++++++++++++ .../shuffle/ExternalShuffleClient.java | 12 ++++--- 6 files changed, 81 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 6840a3ae831f0..a039d543c35e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -47,7 +47,8 @@ class ExternalShuffleService(sparkConf: SparkConf, securityManager: SecurityMana private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) private val blockHandler = newShuffleBlockHandler(transportConf) - private val transportContext: TransportContext = new TransportContext(transportConf, blockHandler) + private val transportContext: TransportContext = + new TransportContext(transportConf, blockHandler, true) private var server: TransportServer = _ diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java index 43900e6f2c972..1b64b863a9fe5 100644 --- a/network/common/src/main/java/org/apache/spark/network/TransportContext.java +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -59,15 +59,24 @@ public class TransportContext { private final TransportConf conf; private final RpcHandler rpcHandler; + private final boolean closeIdleConnections; private final MessageEncoder encoder; private final MessageDecoder decoder; public TransportContext(TransportConf conf, RpcHandler rpcHandler) { + this(conf, rpcHandler, false); + } + + public TransportContext( + TransportConf conf, + RpcHandler rpcHandler, + boolean closeIdleConnections) { this.conf = conf; this.rpcHandler = rpcHandler; this.encoder = new MessageEncoder(); this.decoder = new MessageDecoder(); + this.closeIdleConnections = closeIdleConnections; } /** @@ -144,7 +153,7 @@ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, rpcHandler); return new TransportChannelHandler(client, responseHandler, requestHandler, - conf.connectionTimeoutMs()); + conf.connectionTimeoutMs(), closeIdleConnections); } public TransportConf getConf() { return conf; } diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 4952ffb44bb8b..42a4f664e697c 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -158,6 +158,16 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO } } + /** + * Create a completely new {@link TransportClient} to the given remote host / port + * But this connection is not pooled. + */ + public TransportClient createUnmanagedClient(String remoteHost, int remotePort) + throws IOException { + final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + return createClient(address); + } + /** Create a completely new {@link TransportClient} to the remote address. */ private TransportClient createClient(InetSocketAddress address) throws IOException { logger.debug("Creating new connection to " + address); diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 8e0ee709e38e3..f8fcd1c3d7d76 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -55,16 +55,19 @@ public class TransportChannelHandler extends SimpleChannelInboundHandler 0; + // there's no race between the idle timeout and incrementing the numOutstandingRequests + // (see SPARK-7003). boolean isActuallyOverdue = System.nanoTime() - responseHandler.getTimeOfLastRequestNs() > requestTimeoutNs; - if (e.state() == IdleState.ALL_IDLE && hasInFlightRequests && isActuallyOverdue) { - String address = NettyUtils.getRemoteAddress(ctx.channel()); - logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + - "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + - "is wrong.", address, requestTimeoutNs / 1000 / 1000); - ctx.close(); + if (e.state() == IdleState.ALL_IDLE && isActuallyOverdue) { + if (responseHandler.numOutstandingRequests() > 0) { + String address = NettyUtils.getRemoteAddress(ctx.channel()); + logger.error("Connection to {} has been quiet for {} ms while there are outstanding " + + "requests. Assuming connection is dead; please adjust spark.network.timeout if this " + + "is wrong.", address, requestTimeoutNs / 1000 / 1000); + ctx.close(); + } else if (closeIdleConnections) { + // While CloseIdleConnections is enable, we also close idle connection + ctx.close(); + } } } } diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java index 35de5e57ccb98..f447137419306 100644 --- a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -21,6 +21,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.Map; +import java.util.NoSuchElementException; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; @@ -37,6 +38,7 @@ import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.ConfigProvider; import org.apache.spark.network.util.SystemPropertyConfigProvider; import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.MapConfigProvider; @@ -177,4 +179,36 @@ public void closeBlockClientsWithFactory() throws IOException { assertFalse(c1.isActive()); assertFalse(c2.isActive()); } + + @Test + public void closeIdleConnectionForRequestTimeOut() throws IOException, InterruptedException { + TransportConf conf = new TransportConf(new ConfigProvider() { + + @Override + public String get(String name) { + if ("spark.shuffle.io.connectionTimeout".equals(name)) { + // We should make sure there is enough time for us to observe the channel is active + return "1s"; + } + String value = System.getProperty(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } + }); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); + TransportClientFactory factory = context.createClientFactory(); + try { + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertTrue(c1.isActive()); + long expiredTime = System.currentTimeMillis() + 10000; // 10 seconds + while (c1.isActive() && System.currentTimeMillis() < expiredTime) { + Thread.sleep(10); + } + assertFalse(c1.isActive()); + } finally { + factory.close(); + } + } } diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index ea6d248d66be3..ef3a9dcc8711f 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -78,7 +78,7 @@ protected void checkInit() { @Override public void init(String appId) { this.appId = appId; - TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true); List bootstraps = Lists.newArrayList(); if (saslEnabled) { bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder, saslEncryptionEnabled)); @@ -137,9 +137,13 @@ public void registerWithShuffleServer( String execId, ExecutorShuffleInfo executorInfo) throws IOException { checkInit(); - TransportClient client = clientFactory.createClient(host, port); - byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); - client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + TransportClient client = clientFactory.createUnmanagedClient(host, port); + try { + byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + } finally { + client.close(); + } } @Override From e0701c75601c43f69ed27fc7c252321703db51f2 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 10 Nov 2015 11:06:29 -0800 Subject: [PATCH 77/88] [SPARK-9830][SQL] Remove AggregateExpression1 and Aggregate Operator used to evaluate AggregateExpression1s https://issues.apache.org/jira/browse/SPARK-9830 This PR contains the following main changes. * Removing `AggregateExpression1`. * Removing `Aggregate` operator, which is used to evaluate `AggregateExpression1`. * Removing planner rule used to plan `Aggregate`. * Linking `MultipleDistinctRewriter` to analyzer. * Renaming `AggregateExpression2` to `AggregateExpression` and `AggregateFunction2` to `AggregateFunction`. * Updating places where we create aggregate expression. The way to create aggregate expressions is `AggregateExpression(aggregateFunction, mode, isDistinct)`. * Changing `val`s in `DeclarativeAggregate`s that touch children of this function to `lazy val`s (when we create aggregate expression in DataFrame API, children of an aggregate function can be unresolved). Author: Yin Huai Closes #9556 from yhuai/removeAgg1. --- R/pkg/R/functions.R | 2 +- python/pyspark/sql/dataframe.py | 2 +- python/pyspark/sql/functions.py | 2 +- python/pyspark/sql/tests.py | 2 +- .../spark/sql/catalyst/CatalystConf.scala | 10 +- .../apache/spark/sql/catalyst/SqlParser.scala | 14 +- .../sql/catalyst/analysis/Analyzer.scala | 26 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 46 +- .../DistinctAggregationRewriter.scala} | 235 +--- .../catalyst/analysis/FunctionRegistry.scala | 2 + .../catalyst/analysis/HiveTypeCoercion.scala | 20 +- .../sql/catalyst/analysis/unresolved.scala | 4 + .../spark/sql/catalyst/dsl/package.scala | 22 +- .../expressions/aggregate/Average.scala | 31 +- .../aggregate/CentralMomentAgg.scala | 13 +- .../catalyst/expressions/aggregate/Corr.scala | 15 + .../expressions/aggregate/Count.scala | 28 +- .../expressions/aggregate/First.scala | 14 +- .../aggregate/HyperLogLogPlusPlus.scala | 17 + .../expressions/aggregate/Kurtosis.scala | 2 + .../catalyst/expressions/aggregate/Last.scala | 12 +- .../catalyst/expressions/aggregate/Max.scala | 17 +- .../catalyst/expressions/aggregate/Min.scala | 17 +- .../expressions/aggregate/Skewness.scala | 2 + .../expressions/aggregate/Stddev.scala | 31 +- .../catalyst/expressions/aggregate/Sum.scala | 29 +- .../expressions/aggregate/Variance.scala | 7 +- .../expressions/aggregate/interfaces.scala | 57 +- .../sql/catalyst/expressions/aggregates.scala | 1073 ----------------- .../sql/catalyst/optimizer/Optimizer.scala | 23 +- .../sql/catalyst/planning/patterns.scala | 74 -- .../spark/sql/catalyst/plans/QueryPlan.scala | 12 +- .../plans/logical/basicOperators.scala | 4 +- .../analysis/AnalysisErrorSuite.scala | 23 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../analysis/DecimalPrecisionSuite.scala | 1 + .../ExpressionTypeCheckingSuite.scala | 6 +- .../optimizer/ConstantFoldingSuite.scala | 4 +- .../optimizer/FilterPushdownSuite.scala | 14 +- .../org/apache/spark/sql/DataFrame.scala | 13 +- .../org/apache/spark/sql/GroupedData.scala | 45 +- .../scala/org/apache/spark/sql/SQLConf.scala | 20 +- .../spark/sql/execution/Aggregate.scala | 205 ---- .../apache/spark/sql/execution/Expand.scala | 3 + .../spark/sql/execution/SparkPlanner.scala | 1 - .../spark/sql/execution/SparkStrategies.scala | 238 ++-- .../aggregate/AggregationIterator.scala | 28 +- .../aggregate/SortBasedAggregate.scala | 4 +- .../SortBasedAggregationIterator.scala | 8 +- .../aggregate/TungstenAggregate.scala | 6 +- .../TungstenAggregationIterator.scala | 36 +- .../spark/sql/execution/aggregate/udaf.scala | 2 +- .../spark/sql/execution/aggregate/utils.scala | 20 +- .../spark/sql/expressions/Aggregator.scala | 5 +- .../spark/sql/expressions/WindowSpec.scala | 82 +- .../apache/spark/sql/expressions/udaf.scala | 6 +- .../org/apache/spark/sql/functions.scala | 53 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 69 +- .../spark/sql/UserDefinedTypeSuite.scala | 15 +- .../spark/sql/execution/PlannerSuite.scala | 2 +- .../execution/metric/SQLMetricsSuite.scala | 30 - .../apache/spark/sql/hive/HiveContext.scala | 1 - .../org/apache/spark/sql/hive/HiveQl.scala | 8 +- .../execution/AggregationQuerySuite.scala | 188 ++- 64 files changed, 743 insertions(+), 2260 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/{expressions/aggregate/Utils.scala => analysis/DistinctAggregationRewriter.scala} (58%) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index d7fd279279137..0b280870295a2 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1339,7 +1339,7 @@ setMethod("pmod", signature(y = "Column"), #' @export setMethod("approxCountDistinct", signature(x = "Column"), - function(x, rsd = 0.95) { + function(x, rsd = 0.05) { jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd) column(jc) }) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b97c94dad834a..0dd75ba7ca820 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -866,7 +866,7 @@ def selectExpr(self, *expr): This is a variant of :func:`select` that accepts SQL expressions. >>> df.selectExpr("age * 2", "abs(age)").collect() - [Row((age * 2)=4, 'abs(age)=2), Row((age * 2)=10, 'abs(age)=5)] + [Row((age * 2)=4, abs(age)=2), Row((age * 2)=10, abs(age)=5)] """ if len(expr) == 1 and isinstance(expr[0], list): expr = expr[0] diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 962f676d406d8..6e1cbde4239f3 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -382,7 +382,7 @@ def expr(str): """Parses the expression string into the column that it represents >>> df.select(expr("length(name)")).collect() - [Row('length(name)=5), Row('length(name)=3)] + [Row(length(name)=5), Row(length(name)=3)] """ sc = SparkContext._active_spark_context return Column(sc._jvm.functions.expr(str)) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e224574bcb301..9f5f7cfdf7a69 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1017,7 +1017,7 @@ def test_expr(self): row = Row(a="length string", b=75) df = self.sqlCtx.createDataFrame([row]) result = df.select(functions.expr("length(a)")).collect()[0].asDict() - self.assertEqual(13, result["'length(a)"]) + self.assertEqual(13, result["length(a)"]) def test_replace(self): schema = StructType([ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 3f351b07b37df..7c2b8a9407884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst private[spark] trait CatalystConf { def caseSensitiveAnalysis: Boolean + + protected[spark] def specializeSingleDistinctAggPlanning: Boolean } /** @@ -29,7 +31,13 @@ object EmptyConf extends CatalystConf { override def caseSensitiveAnalysis: Boolean = { throw new UnsupportedOperationException } + + protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = { + throw new UnsupportedOperationException + } } /** A CatalystConf that can be used for local testing. */ -case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf +case class SimpleCatalystConf(caseSensitiveAnalysis: Boolean) extends CatalystConf { + protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = true +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index cd717c09f8e5e..2a132d8b82bef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -22,6 +22,7 @@ import scala.language.implicitConversions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.DataTypeParser @@ -272,7 +273,7 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val function: Parser[Expression] = ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName => if (lexical.normalizeKeyword(udfName) == "count") { - Count(Literal(1)) + AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false) } else { throw new AnalysisException(s"invalid expression $udfName(*)") } @@ -281,14 +282,14 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) } | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs => lexical.normalizeKeyword(udfName) match { - case "sum" => SumDistinct(exprs.head) - case "count" => CountDistinct(exprs) + case "count" => + aggregate.Count(exprs).toAggregateExpression(isDistinct = true) case _ => UnresolvedFunction(udfName, exprs, isDistinct = true) } } | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp => if (lexical.normalizeKeyword(udfName) == "count") { - ApproxCountDistinct(exp) + AggregateExpression(new HyperLogLogPlusPlus(exp), mode = Complete, isDistinct = false) } else { throw new AnalysisException(s"invalid function approximate $udfName") } @@ -296,7 +297,10 @@ object SqlParser extends AbstractSparkSQLParser with DataTypeParser { | APPROXIMATE ~> "(" ~> unsignedFloat ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ { case s ~ _ ~ udfName ~ _ ~ _ ~ exp => if (lexical.normalizeKeyword(udfName) == "count") { - ApproxCountDistinct(exp, s.toDouble) + AggregateExpression( + HyperLogLogPlusPlus(exp, s.toDouble, 0, 0), + mode = Complete, + isDistinct = false) } else { throw new AnalysisException(s"invalid function approximate($s) $udfName") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 899ee67352df4..b1e14390b7dc0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef @@ -79,6 +79,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: + DistinctAggregationRewriter(conf) :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, @@ -525,21 +526,14 @@ class Analyzer( case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { registry.lookupFunction(name, children) match { - // We get an aggregate function built based on AggregateFunction2 interface. - // So, we wrap it in AggregateExpression2. - case agg2: AggregateFunction2 => AggregateExpression2(agg2, Complete, isDistinct) - // Currently, our old aggregate function interface supports SUM(DISTINCT ...) - // and COUTN(DISTINCT ...). - case sumDistinct: SumDistinct => sumDistinct - case countDistinct: CountDistinct => countDistinct - // DISTINCT is not meaningful with Max and Min. - case max: Max if isDistinct => max - case min: Min if isDistinct => min - // For other aggregate functions, DISTINCT keyword is not supported for now. - // Once we converted to the new code path, we will allow using DISTINCT keyword. - case other: AggregateExpression1 if isDistinct => - failAnalysis(s"$name does not support DISTINCT keyword.") - // If it does not have DISTINCT keyword, we will return it as is. + // DISTINCT is not meaningful for a Max or a Min. + case max: Max if isDistinct => + AggregateExpression(max, Complete, isDistinct = false) + case min: Min if isDistinct => + AggregateExpression(min, Complete, isDistinct = false) + // We get an aggregate function, we need to wrap it in an AggregateExpression. + case agg2: AggregateFunction => AggregateExpression(agg2, Complete, isDistinct) + // This function is not an aggregate function, just return the resolved one. case other => other } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 98d6637c0601b..8322e9930cd5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -108,7 +109,19 @@ trait CheckAnalysis { case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case _: AggregateExpression => // OK + case aggExpr: AggregateExpression => + // TODO: Is it possible that the child of a agg function is another + // agg function? + aggExpr.aggregateFunction.children.foreach { + // This is just a sanity check, our analysis rule PullOutNondeterministic should + // already pull out those nondeterministic expressions and evaluate them in + // a Project node. + case child if !child.deterministic => + failAnalysis( + s"nondeterministic expression ${expr.prettyString} should not " + + s"appear in the arguments of an aggregate function.") + case child => // OK + } case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) => failAnalysis( s"expression '${e.prettyString}' is neither present in the group by, " + @@ -120,14 +133,26 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } - def checkValidGroupingExprs(expr: Expression): Unit = expr.dataType match { - case BinaryType => - failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case m: MapType => - failAnalysis(s"map type expression ${expr.prettyString} cannot be used " + - "in grouping expression") - case _ => // OK + def checkValidGroupingExprs(expr: Expression): Unit = { + expr.dataType match { + case BinaryType => + failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " + + "in grouping expression") + case a: ArrayType => + failAnalysis(s"array type expression ${expr.prettyString} cannot be used " + + "in grouping expression") + case m: MapType => + failAnalysis(s"map type expression ${expr.prettyString} cannot be used " + + "in grouping expression") + case _ => // OK + } + if (!expr.deterministic) { + // This is just a sanity check, our analysis rule PullOutNondeterministic should + // already pull out those nondeterministic expressions and evaluate them in + // a Project node. + failAnalysis(s"nondeterministic expression ${expr.prettyString} should not " + + s"appear in grouping expression.") + } } aggregateExprs.foreach(checkValidAggregateExpression) @@ -179,7 +204,8 @@ trait CheckAnalysis { s"unresolved operator ${operator.simpleString}") case o if o.expressions.exists(!_.deterministic) && - !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] => + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] & !o.isInstanceOf[Aggregate] => + // The rule above is used to check Aggregate operator. failAnalysis( s"""nondeterministic expressions are only allowed in Project or Filter, found: | ${o.expressions.map(_.prettyString).mkString(",")} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala similarity index 58% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala index 9b22ce2619731..397eff05686b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Utils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DistinctAggregationRewriter.scala @@ -15,215 +15,17 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.expressions.aggregate +package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Expand, Aggregate, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.IntegerType /** - * Utility functions used by the query planner to convert our plan to new aggregation code path. - */ -object Utils { - - // Check if the DataType given cannot be part of a group by clause. - private def isUnGroupable(dt: DataType): Boolean = dt match { - case _: ArrayType | _: MapType => true - case s: StructType => s.fields.exists(f => isUnGroupable(f.dataType)) - case _ => false - } - - // Right now, we do not support complex types in the grouping key schema. - private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = - !aggregate.groupingExpressions.exists(e => isUnGroupable(e.dataType)) - - private def doConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate if supportsGroupingKeySchema(p) => - - val converted = MultipleDistinctRewriter.rewrite(p.transformExpressionsDown { - case expressions.Average(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Average(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Count(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.CountDistinct(children) => - val child = if (children.size > 1) { - DropAnyNull(CreateStruct(children)) - } else { - children.head - } - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Count(child), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.First(child, ignoreNulls) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.First(child, ignoreNulls), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Kurtosis(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Kurtosis(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Last(child, ignoreNulls) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Last(child, ignoreNulls), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Max(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Max(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Min(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Min(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Skewness(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Skewness(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.StddevPop(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.StddevPop(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.StddevSamp(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.StddevSamp(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.Sum(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.SumDistinct(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Sum(child), - mode = aggregate.Complete, - isDistinct = true) - - case expressions.Corr(left, right) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.Corr(left, right), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.ApproxCountDistinct(child, rsd) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.HyperLogLogPlusPlus(child, rsd), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.VariancePop(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.VariancePop(child), - mode = aggregate.Complete, - isDistinct = false) - - case expressions.VarianceSamp(child) => - aggregate.AggregateExpression2( - aggregateFunction = aggregate.VarianceSamp(child), - mode = aggregate.Complete, - isDistinct = false) - }) - - // Check if there is any expressions.AggregateExpression1 left. - // If so, we cannot convert this plan. - val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => - // For every expressions, check if it contains AggregateExpression1. - expr.find { - case agg: expressions.AggregateExpression1 => true - case other => false - }.isDefined - } - - // Check if there are multiple distinct columns. - // TODO remove this. - val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.toSet.toSeq - val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) - val hasMultipleDistinctColumnSets = - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - true - } else { - false - } - - if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None - - case other => None - } - - def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { - // If the plan cannot be converted, we will do a final round check to see if the original - // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, - // we need to throw an exception. - val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg.aggregateFunction - } - }.distinct - if (aggregateFunction2s.nonEmpty) { - // For functions implemented based on the new interface, prepare a list of function names. - val invalidFunctions = { - if (aggregateFunction2s.length > 1) { - s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + - s"and ${aggregateFunction2s.head.nodeName} are" - } else { - s"${aggregateFunction2s.head.nodeName} is" - } - } - val errorMessage = - s"${invalidFunctions} implemented based on the new Aggregate Function " + - s"interface and it cannot be used with functions implemented based on " + - s"the old Aggregate Function interface." - throw new AnalysisException(errorMessage) - } - } - - def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { - case p: Aggregate => - val converted = doConvert(p) - if (converted.isDefined) { - converted - } else { - checkInvalidAggregateFunction2(p) - None - } - case other => None - } -} - -/** - * This rule rewrites an aggregate query with multiple distinct clauses into an expanded double + * This rule rewrites an aggregate query with distinct aggregations into an expanded double * aggregation in which the regular aggregation expressions and every distinct clause is aggregated * in a separate group. The results are then combined in a second aggregate. * @@ -298,9 +100,11 @@ object Utils { * we could improve this in the current rule by applying more advanced expression cannocalization * techniques. */ -object MultipleDistinctRewriter extends Rule[LogicalPlan] { +case class DistinctAggregationRewriter(conf: CatalystConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.resolved => p + // We need to wait until this Aggregate operator is resolved. case a: Aggregate => rewrite(a) case p => p } @@ -310,7 +114,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Collect all aggregate expressions. val aggExpressions = a.aggregateExpressions.flatMap { e => e.collect { - case ae: AggregateExpression2 => ae + case ae: AggregateExpression => ae } } @@ -319,8 +123,15 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { .filter(_.isDistinct) .groupBy(_.aggregateFunction.children.toSet) - // Only continue to rewrite if there is more than one distinct group. - if (distinctAggGroups.size > 1) { + val shouldRewrite = if (conf.specializeSingleDistinctAggPlanning) { + // When the flag is set to specialize single distinct agg planning, + // we will rely on our Aggregation strategy to handle queries with a single + // distinct column and this aggregate operator does have grouping expressions. + distinctAggGroups.size > 1 || (distinctAggGroups.size == 1 && a.groupingExpressions.isEmpty) + } else { + distinctAggGroups.size >= 1 + } + if (shouldRewrite) { // Create the attributes for the grouping id and the group by clause. val gid = new AttributeReference("gid", IntegerType, false)() val groupByMap = a.groupingExpressions.collect { @@ -332,11 +143,11 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { // Functions used to modify aggregate functions and their inputs. def evalWithinGroup(id: Literal, e: Expression) = If(EqualTo(gid, id), e, nullify(e)) def patchAggregateFunctionChildren( - af: AggregateFunction2)( - attrs: Expression => Expression): AggregateFunction2 = { + af: AggregateFunction)( + attrs: Expression => Expression): AggregateFunction = { af.withNewChildren(af.children.map { case afc => attrs(afc) - }).asInstanceOf[AggregateFunction2] + }).asInstanceOf[AggregateFunction] } // Setup unique distinct aggregate children. @@ -381,7 +192,7 @@ object MultipleDistinctRewriter extends Rule[LogicalPlan] { val operator = Alias(e.copy(aggregateFunction = af), e.prettyString)() // Select the result of the first aggregate in the last aggregate. - val result = AggregateExpression2( + val result = AggregateExpression( aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)), mode = Complete, isDistinct = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d4334d16289a5..dfa749d1afa5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -24,6 +24,7 @@ import scala.util.{Failure, Success, Try} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.StringKeyHashMap @@ -177,6 +178,7 @@ object FunctionRegistry { expression[ToRadians]("radians"), // aggregate functions + expression[HyperLogLogPlusPlus]("approx_count_distinct"), expression[Average]("avg"), expression[Corr]("corr"), expression[Count]("count"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 84e2b1366f626..bf2bff0243fa3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import javax.annotation.Nullable import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ @@ -295,14 +296,17 @@ object HiveTypeCoercion { i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) - case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) case StddevPop(e @ StringType()) => StddevPop(Cast(e, DoubleType)) case StddevSamp(e @ StringType()) => StddevSamp(Cast(e, DoubleType)) - case VariancePop(e @ StringType()) => VariancePop(Cast(e, DoubleType)) - case VarianceSamp(e @ StringType()) => VarianceSamp(Cast(e, DoubleType)) - case Skewness(e @ StringType()) => Skewness(Cast(e, DoubleType)) - case Kurtosis(e @ StringType()) => Kurtosis(Cast(e, DoubleType)) + case VariancePop(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + VariancePop(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case VarianceSamp(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + VarianceSamp(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case Skewness(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + Skewness(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) + case Kurtosis(e @ StringType(), mutableAggBufferOffset, inputAggBufferOffset) => + Kurtosis(Cast(e, DoubleType), mutableAggBufferOffset, inputAggBufferOffset) } } @@ -562,12 +566,6 @@ object HiveTypeCoercion { case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) - case s @ SumDistinct(e @ DecimalType()) => s // Decimal is already the biggest. - case SumDistinct(e @ IntegralType()) if e.dataType != LongType => - SumDistinct(Cast(e, LongType)) - case SumDistinct(e @ FractionalType()) if e.dataType != DoubleType => - SumDistinct(Cast(e, DoubleType)) - case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest. case Average(e @ IntegralType()) if e.dataType != LongType => Average(Cast(e, LongType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index eae17c86ddc7a..6485bdfb30234 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -141,6 +141,10 @@ case class UnresolvedFunction( override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false + override def prettyString: String = { + s"${name}(${children.map(_.prettyString).mkString(",")})" + } + override def toString: String = s"'$name(${children.mkString(",")})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index d8df66430a695..af594c25c54cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -23,6 +23,7 @@ import scala.language.implicitConversions import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.types._ @@ -144,17 +145,18 @@ package object dsl { } } - def sum(e: Expression): Expression = Sum(e) - def sumDistinct(e: Expression): Expression = SumDistinct(e) - def count(e: Expression): Expression = Count(e) - def countDistinct(e: Expression*): Expression = CountDistinct(e) + def sum(e: Expression): Expression = Sum(e).toAggregateExpression() + def sumDistinct(e: Expression): Expression = Sum(e).toAggregateExpression(isDistinct = true) + def count(e: Expression): Expression = Count(e).toAggregateExpression() + def countDistinct(e: Expression*): Expression = + Count(e).toAggregateExpression(isDistinct = true) def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = - ApproxCountDistinct(e, rsd) - def avg(e: Expression): Expression = Average(e) - def first(e: Expression): Expression = First(e) - def last(e: Expression): Expression = Last(e) - def min(e: Expression): Expression = Min(e) - def max(e: Expression): Expression = Max(e) + HyperLogLogPlusPlus(e, rsd).toAggregateExpression() + def avg(e: Expression): Expression = Average(e).toAggregateExpression() + def first(e: Expression): Expression = new First(e).toAggregateExpression() + def last(e: Expression): Expression = new Last(e).toAggregateExpression() + def min(e: Expression): Expression = Min(e).toAggregateExpression() + def max(e: Expression): Expression = Max(e).toAggregateExpression() def upper(e: Expression): Expression = Upper(e) def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index c8c20ada5fbc7..7f9e5034702e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ case class Average(child: Expression) extends DeclarativeAggregate { @@ -32,36 +34,33 @@ case class Average(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = resultType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) - private val resultType = child.dataType match { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") + + private lazy val resultType = child.dataType match { case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 4, s + 4) case _ => DoubleType } - private val sumDataType = child.dataType match { + private lazy val sumDataType = child.dataType match { case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) case _ => DoubleType } - private val sum = AttributeReference("sum", sumDataType)() - private val count = AttributeReference("count", LongType)() + private lazy val sum = AttributeReference("sum", sumDataType)() + private lazy val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = sum :: count :: Nil + override lazy val aggBufferAttributes = sum :: count :: Nil - override val initialValues = Seq( + override lazy val initialValues = Seq( /* sum = */ Cast(Literal(0), sumDataType), /* count = */ Literal(0L) ) - override val updateExpressions = Seq( + override lazy val updateExpressions = Seq( /* sum = */ Add( sum, @@ -69,13 +68,13 @@ case class Average(child: Expression) extends DeclarativeAggregate { /* count = */ If(IsNull(child), count, count + 1L) ) - override val mergeExpressions = Seq( + override lazy val mergeExpressions = Seq( /* sum = */ sum.left + sum.right, /* count = */ count.left + count.right ) // If all input are nulls, count will be 0 and we will get null after the division. - override val evaluateExpression = child.dataType match { + override lazy val evaluateExpression = child.dataType match { case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index ef08b025ff556..984ce7f24dacc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -55,13 +57,10 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def dataType: DataType = DoubleType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select avg(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName") override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 832338378fb38..00d7436b710d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -35,6 +37,9 @@ case class Corr( inputAggBufferOffset: Int = 0) extends ImperativeAggregate { + def this(left: Expression, right: Expression) = + this(left, right, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def children: Seq[Expression] = Seq(left, right) override def nullable: Boolean = false @@ -43,6 +48,16 @@ case class Corr( override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) + override def checkInputDataTypes(): TypeCheckResult = { + if (left.dataType.isInstanceOf[DoubleType] && right.dataType.isInstanceOf[DoubleType]) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"corr requires that both arguments are double type, " + + s"not (${left.dataType}, ${right.dataType}).") + } + } + override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) override def inputAggBufferAttributes: Seq[AttributeReference] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index ec0c8b483a909..09a1da9200df0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -32,23 +32,39 @@ case class Count(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val count = AttributeReference("count", LongType)() + private lazy val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = count :: Nil + override lazy val aggBufferAttributes = count :: Nil - override val initialValues = Seq( + override lazy val initialValues = Seq( /* count = */ Literal(0L) ) - override val updateExpressions = Seq( + override lazy val updateExpressions = Seq( /* count = */ If(IsNull(child), count, count + 1L) ) - override val mergeExpressions = Seq( + override lazy val mergeExpressions = Seq( /* count = */ count.left + count.right ) - override val evaluateExpression = Cast(count, LongType) + override lazy val evaluateExpression = Cast(count, LongType) override def defaultResult: Option[Literal] = Option(Literal(0L)) } + +object Count { + def apply(children: Seq[Expression]): Count = { + // This is used to deal with COUNT DISTINCT. When we have multiple + // children (COUNT(DISTINCT col1, col2, ...)), we wrap them in a STRUCT (i.e. a Row). + // Also, the semantic of COUNT(DISTINCT col1, col2, ...) is that if there is any + // null in the arguments, we will not count that row. So, we use DropAnyNull at here + // to return a null when any field of the created STRUCT is null. + val child = if (children.size > 1) { + DropAnyNull(CreateStruct(children)) + } else { + children.head + } + Count(child) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 9028143015853..35f57426feaf2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -51,18 +51,18 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val first = AttributeReference("first", child.dataType)() + private lazy val first = AttributeReference("first", child.dataType)() - private val valueSet = AttributeReference("valueSet", BooleanType)() + private lazy val valueSet = AttributeReference("valueSet", BooleanType)() - override val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = first :: valueSet :: Nil - override val initialValues: Seq[Literal] = Seq( + override lazy val initialValues: Seq[Literal] = Seq( /* first = */ Literal.create(null, child.dataType), /* valueSet = */ Literal.create(false, BooleanType) ) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* first = */ If(Or(valueSet, IsNull(child)), first, child), @@ -76,7 +76,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara } } - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { // For first, we can just check if valueSet.left is set to true. If it is set // to true, we use first.right. If not, we use first.right (even if valueSet.right is // false, we are safe to do so because first.right will be null in this case). @@ -86,7 +86,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara ) } - override val evaluateExpression: AttributeReference = first + override lazy val evaluateExpression: AttributeReference = first override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index 8d341ee630bdb..8a95c541f1e86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -22,6 +22,7 @@ import java.util import com.clearspring.analytics.hash.MurmurHash +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -55,6 +56,22 @@ case class HyperLogLogPlusPlus( extends ImperativeAggregate { import HyperLogLogPlusPlus._ + def this(child: Expression) = { + this(child = child, relativeSD = 0.05, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + } + + def this(child: Expression, relativeSD: Expression) = { + this( + child = child, + relativeSD = relativeSD match { + case Literal(d: Double, DoubleType) => d + case _ => + throw new AnalysisException("The second argument should be a double literal.") + }, + mutableAggBufferOffset = 0, + inputAggBufferOffset = 0) + } + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala index 6da39e7143447..bae78d98493b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Kurtosis.scala @@ -24,6 +24,8 @@ case class Kurtosis(child: Expression, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 8636bfe8d07aa..be7e12d7a2336 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -51,15 +51,15 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val last = AttributeReference("last", child.dataType)() + private lazy val last = AttributeReference("last", child.dataType)() - override val aggBufferAttributes: Seq[AttributeReference] = last :: Nil + override lazy val aggBufferAttributes: Seq[AttributeReference] = last :: Nil - override val initialValues: Seq[Literal] = Seq( + override lazy val initialValues: Seq[Literal] = Seq( /* last = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* last = */ If(IsNull(child), last, child) @@ -71,7 +71,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat } } - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { if (ignoreNulls) { Seq( /* last = */ If(IsNull(last.right), last.left, last.right) @@ -83,7 +83,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat } } - override val evaluateExpression: AttributeReference = last + override lazy val evaluateExpression: AttributeReference = last override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index b9d75ad452838..61cae44cd0f5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ case class Max(child: Expression) extends DeclarativeAggregate { @@ -32,24 +34,27 @@ case class Max(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val max = AttributeReference("max", child.dataType)() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function max") - override val aggBufferAttributes: Seq[AttributeReference] = max :: Nil + private lazy val max = AttributeReference("max", child.dataType)() - override val initialValues: Seq[Literal] = Seq( + override lazy val aggBufferAttributes: Seq[AttributeReference] = max :: Nil + + override lazy val initialValues: Seq[Literal] = Seq( /* max = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = Seq( + override lazy val updateExpressions: Seq[Expression] = Seq( /* max = */ If(IsNull(child), max, If(IsNull(max), child, Greatest(Seq(max, child)))) ) - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { val greatest = Greatest(Seq(max.left, max.right)) Seq( /* max = */ If(IsNull(max.right), max.left, If(IsNull(max.left), max.right, greatest)) ) } - override val evaluateExpression: AttributeReference = max + override lazy val evaluateExpression: AttributeReference = max } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 5ed9cd348daba..242456d9e2e18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -33,24 +35,27 @@ case class Min(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val min = AttributeReference("min", child.dataType)() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function min") - override val aggBufferAttributes: Seq[AttributeReference] = min :: Nil + private lazy val min = AttributeReference("min", child.dataType)() - override val initialValues: Seq[Expression] = Seq( + override lazy val aggBufferAttributes: Seq[AttributeReference] = min :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( /* min = */ Literal.create(null, child.dataType) ) - override val updateExpressions: Seq[Expression] = Seq( + override lazy val updateExpressions: Seq[Expression] = Seq( /* min = */ If(IsNull(child), min, If(IsNull(min), child, Least(Seq(min, child)))) ) - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { val least = Least(Seq(min.left, min.right)) Seq( /* min = */ If(IsNull(min.right), min.left, If(IsNull(min.left), min.right, least)) ) } - override val evaluateExpression: AttributeReference = min + override lazy val evaluateExpression: AttributeReference = min } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala index 0def7ddfd9d3d..c593074fa2479 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Skewness.scala @@ -24,6 +24,8 @@ case class Skewness(child: Expression, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala index 3f47ffe13cbc8..5b9eb7ae02f25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -48,29 +50,26 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { override def dataType: DataType = resultType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select stddev(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. - override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType)) + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType)) - private val resultType = DoubleType + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function stddev") - private val count = AttributeReference("count", resultType)() - private val avg = AttributeReference("avg", resultType)() - private val mk = AttributeReference("mk", resultType)() + private lazy val resultType = DoubleType - override val aggBufferAttributes = count :: avg :: mk :: Nil + private lazy val count = AttributeReference("count", resultType)() + private lazy val avg = AttributeReference("avg", resultType)() + private lazy val mk = AttributeReference("mk", resultType)() - override val initialValues: Seq[Expression] = Seq( + override lazy val aggBufferAttributes = count :: avg :: mk :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( /* count = */ Cast(Literal(0), resultType), /* avg = */ Cast(Literal(0), resultType), /* mk = */ Cast(Literal(0), resultType) ) - override val updateExpressions: Seq[Expression] = { + override lazy val updateExpressions: Seq[Expression] = { val value = Cast(child, resultType) val newCount = count + Cast(Literal(1), resultType) @@ -89,7 +88,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { ) } - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { // count merge val newCount = count.left + count.right @@ -114,7 +113,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { ) } - override val evaluateExpression: Expression = { + override lazy val evaluateExpression: Expression = { // when count == 0, return null // when count == 1, return 0 // when count >1 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 7f8adbc56ad1d..c005ec9657211 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ case class Sum(child: Expression) extends DeclarativeAggregate { @@ -29,16 +31,13 @@ case class Sum(child: Expression) extends DeclarativeAggregate { // Return data type. override def dataType: DataType = resultType - // Expected input data type. - // TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) to the - // new version at planning time (after analysis phase). For now, NullType is added at here - // to make it resolved when we have cases like `select sum(null)`. - // We can use our analyzer to cast NullType to the default data type of the NumericType once - // we remove the old aggregate functions. Then, we will not need NullType at here. override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType)) - private val resultType = child.dataType match { + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sum") + + private lazy val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) // TODO: Remove this line once we remove the NullType from inputTypes. @@ -46,24 +45,24 @@ case class Sum(child: Expression) extends DeclarativeAggregate { case _ => child.dataType } - private val sumDataType = resultType + private lazy val sumDataType = resultType - private val sum = AttributeReference("sum", sumDataType)() + private lazy val sum = AttributeReference("sum", sumDataType)() - private val zero = Cast(Literal(0), sumDataType) + private lazy val zero = Cast(Literal(0), sumDataType) - override val aggBufferAttributes = sum :: Nil + override lazy val aggBufferAttributes = sum :: Nil - override val initialValues: Seq[Expression] = Seq( + override lazy val initialValues: Seq[Expression] = Seq( /* sum = */ Literal.create(null, sumDataType) ) - override val updateExpressions: Seq[Expression] = Seq( + override lazy val updateExpressions: Seq[Expression] = Seq( /* sum = */ Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) ) - override val mergeExpressions: Seq[Expression] = { + override lazy val mergeExpressions: Seq[Expression] = { val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) Seq( /* sum = */ @@ -71,5 +70,5 @@ case class Sum(child: Expression) extends DeclarativeAggregate { ) } - override val evaluateExpression: Expression = Cast(sum, resultType) + override lazy val evaluateExpression: Expression = Cast(sum, resultType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala index ec63534e5290a..ede2da2805966 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Variance.scala @@ -24,6 +24,8 @@ case class VarianceSamp(child: Expression, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) @@ -42,11 +44,14 @@ case class VarianceSamp(child: Expression, } } -case class VariancePop(child: Expression, +case class VariancePop( + child: Expression, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) extends CentralMomentAgg(child) { + def this(child: Expression) = this(child, mutableAggBufferOffset = 0, inputAggBufferOffset = 0) + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 5c5b3d1ccd3cd..3b441de34a49f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -17,23 +17,24 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ -/** The mode of an [[AggregateFunction2]]. */ +/** The mode of an [[AggregateFunction]]. */ private[sql] sealed trait AggregateMode /** - * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation. + * An [[AggregateFunction]] with [[Partial]] mode is used for partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the aggregation buffer is returned. */ private[sql] case object Partial extends AggregateMode /** - * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers + * An [[AggregateFunction]] with [[PartialMerge]] mode is used to merge aggregation buffers * containing intermediate results for this function. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the aggregation buffer is returned. @@ -41,7 +42,7 @@ private[sql] case object Partial extends AggregateMode private[sql] case object PartialMerge extends AggregateMode /** - * An [[AggregateFunction2]] with [[Final]] mode is used to merge aggregation buffers + * An [[AggregateFunction]] with [[Final]] mode is used to merge aggregation buffers * containing intermediate results for this function and then generate final result. * This function updates the given aggregation buffer by merging multiple aggregation buffers. * When it has processed all input rows, the final result of this function is returned. @@ -49,7 +50,7 @@ private[sql] case object PartialMerge extends AggregateMode private[sql] case object Final extends AggregateMode /** - * An [[AggregateFunction2]] with [[Complete]] mode is used to evaluate this function directly + * An [[AggregateFunction]] with [[Complete]] mode is used to evaluate this function directly * from original input rows without any partial aggregation. * This function updates the given aggregation buffer with the original input of this * function. When it has processed all input rows, the final result of this function is returned. @@ -67,13 +68,15 @@ private[sql] case object NoOp extends Expression with Unevaluable { } /** - * A container for an [[AggregateFunction2]] with its [[AggregateMode]] and a field + * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. */ -private[sql] case class AggregateExpression2( - aggregateFunction: AggregateFunction2, +private[sql] case class AggregateExpression( + aggregateFunction: AggregateFunction, mode: AggregateMode, - isDistinct: Boolean) extends AggregateExpression { + isDistinct: Boolean) + extends Expression + with Unevaluable { override def children: Seq[Expression] = aggregateFunction :: Nil override def dataType: DataType = aggregateFunction.dataType @@ -89,6 +92,8 @@ private[sql] case class AggregateExpression2( AttributeSet(childReferences) } + override def prettyString: String = aggregateFunction.prettyString + override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" } @@ -106,10 +111,10 @@ private[sql] case class AggregateExpression2( * combined aggregation buffer which concatenates the aggregation buffers of the individual * aggregate functions. * - * Code which accepts [[AggregateFunction2]] instances should be prepared to handle both types of + * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of * aggregate functions. */ -sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInputTypes { +sealed abstract class AggregateFunction extends Expression with ImplicitCastInputTypes { /** An aggregate function is not foldable. */ final override def foldable: Boolean = false @@ -141,6 +146,27 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + + /** + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because + * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, + * and the flag indicating if this aggregation is distinct aggregation or not. + * An [[AggregateFunction]] should not be used without being wrapped in + * an [[AggregateExpression]]. + */ + def toAggregateExpression(): AggregateExpression = toAggregateExpression(isDistinct = false) + + /** + * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] and set isDistinct + * field of the [[AggregateExpression]] to the given value because + * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, + * and the flag indicating if this aggregation is distinct aggregation or not. + * An [[AggregateFunction]] should not be used without being wrapped in + * an [[AggregateExpression]]. + */ + def toAggregateExpression(isDistinct: Boolean): AggregateExpression = { + AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct) + } } /** @@ -161,7 +187,7 @@ sealed abstract class AggregateFunction2 extends Expression with ImplicitCastInp * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes` * and `inputAggBufferAttributes`. */ -abstract class ImperativeAggregate extends AggregateFunction2 { +abstract class ImperativeAggregate extends AggregateFunction { /** * The offset of this function's first buffer value in the underlying shared mutable aggregation @@ -258,9 +284,14 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * `bufferAttributes`, defining attributes for the fields of the mutable aggregation buffer. You * can then use these attributes when defining `updateExpressions`, `mergeExpressions`, and * `evaluateExpressions`. + * + * Please note that children of an aggregate function can be unresolved (it will happen when + * we create this function in DataFrame API). So, if there is any fields in + * the implemented class that need to access fields of its children, please make + * those fields `lazy val`s. */ abstract class DeclarativeAggregate - extends AggregateFunction2 + extends AggregateFunction with Serializable with Unevaluable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala deleted file mode 100644 index 3dcf7915d77b3..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ /dev/null @@ -1,1073 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import com.clearspring.analytics.stream.cardinality.HyperLogLog - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData, TypeUtils} -import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.OpenHashSet - - -trait AggregateExpression extends Expression with Unevaluable - -trait AggregateExpression1 extends AggregateExpression { - - /** - * Aggregate expressions should not be foldable. - */ - override def foldable: Boolean = false - - /** - * Creates a new instance that can be used to compute this aggregate expression for a group - * of input rows/ - */ - def newInstance(): AggregateFunction1 -} - -/** - * Represents an aggregation that has been rewritten to be performed in two steps. - * - * @param finalEvaluation an aggregate expression that evaluates to same final result as the - * original aggregation. - * @param partialEvaluations A sequence of [[NamedExpression]]s that can be computed on partial - * data sets and are required to compute the `finalEvaluation`. - */ -case class SplitEvaluation( - finalEvaluation: Expression, - partialEvaluations: Seq[NamedExpression]) - -/** - * An [[AggregateExpression1]] that can be partially computed without seeing all relevant tuples. - * These partial evaluations can then be combined to compute the actual answer. - */ -trait PartialAggregate1 extends AggregateExpression1 { - - /** - * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation. - */ - def asPartial: SplitEvaluation -} - -/** - * A specific implementation of an aggregate function. Used to wrap a generic - * [[AggregateExpression1]] with an algorithm that will be used to compute one specific result. - */ -abstract class AggregateFunction1 extends LeafExpression with Serializable { - - /** Base should return the generic aggregate expression that this function is computing */ - val base: AggregateExpression1 - - override def nullable: Boolean = base.nullable - override def dataType: DataType = base.dataType - - def update(input: InternalRow): Unit - - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - throw new UnsupportedOperationException( - "AggregateFunction1 should not be used for generated aggregates") - } -} - -case class Min(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - - override def asPartial: SplitEvaluation = { - val partialMin = Alias(Min(child), "PartialMin")() - SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil) - } - - override def newInstance(): MinFunction = new MinFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(child.dataType, "function min") -} - -case class MinFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) - val cmp = GreaterThan(currentMin, expr) - - override def update(input: InternalRow): Unit = { - if (currentMin.value == null) { - currentMin.value = expr.eval(input) - } else if (cmp.eval(input) == true) { - currentMin.value = expr.eval(input) - } - } - - override def eval(input: InternalRow): Any = currentMin.value -} - -case class Max(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - - override def asPartial: SplitEvaluation = { - val partialMax = Alias(Max(child), "PartialMax")() - SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil) - } - - override def newInstance(): MaxFunction = new MaxFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForOrderingExpr(child.dataType, "function max") -} - -case class MaxFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) - val cmp = LessThan(currentMax, expr) - - override def update(input: InternalRow): Unit = { - if (currentMax.value == null) { - currentMax.value = expr.eval(input) - } else if (cmp.eval(input) == true) { - currentMax.value = expr.eval(input) - } - } - - override def eval(input: InternalRow): Any = currentMax.value -} - -case class Count(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - - override def asPartial: SplitEvaluation = { - val partialCount = Alias(Count(child), "PartialCount")() - SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil) - } - - override def newInstance(): CountFunction = new CountFunction(child, this) -} - -case class CountFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - var count: Long = _ - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1L - } - } - - override def eval(input: InternalRow): Any = count -} - -case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate1 { - def this() = this(null) - - override def children: Seq[Expression] = expressions - - override def nullable: Boolean = false - override def dataType: DataType = LongType - override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})" - override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(expressions), "partialSets")() - SplitEvaluation( - CombineSetsAndCount(partialSet.toAttribute), - partialSet :: Nil) - } -} - -case class CountDistinctFunction( - @transient expr: Seq[Expression], - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - @transient - val distinctValue = new InterpretedProjection(expr) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = seen.size.toLong -} - -case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = expressions - override def nullable: Boolean = false - override def dataType: OpenHashSetUDT = new OpenHashSetUDT(expressions.head.dataType) - override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance(): CollectHashSetFunction = - new CollectHashSetFunction(expressions, this) -} - -case class CollectHashSetFunction( - @transient expr: Seq[Expression], - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - @transient - val distinctValue = new InterpretedProjection(expr) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = { - seen - } -} - -case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = false - override def dataType: DataType = LongType - override def toString: String = s"CombineAndCount($inputSet)" - override def newInstance(): CombineSetsAndCountFunction = { - new CombineSetsAndCountFunction(inputSet, this) - } -} - -case class CombineSetsAndCountFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - override def update(input: InternalRow): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next) - } - } - - override def eval(input: InternalRow): Any = seen.size.toLong -} - -/** The data type of ApproxCountDistinctPartition since its output is a HyperLogLog object. */ -private[sql] case object HyperLogLogUDT extends UserDefinedType[HyperLogLog] { - - override def sqlType: DataType = BinaryType - - /** Since we are using HyperLogLog internally, usually it will not be called. */ - override def serialize(obj: Any): Array[Byte] = - obj.asInstanceOf[HyperLogLog].getBytes - - - /** Since we are using HyperLogLog internally, usually it will not be called. */ - override def deserialize(datum: Any): HyperLogLog = - HyperLogLog.Builder.build(datum.asInstanceOf[Array[Byte]]) - - override def userClass: Class[HyperLogLog] = classOf[HyperLogLog] -} - -case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression1 { - - override def nullable: Boolean = false - override def dataType: DataType = HyperLogLogUDT - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance(): ApproxCountDistinctPartitionFunction = { - new ApproxCountDistinctPartitionFunction(child, this, relativeSD) - } -} - -case class ApproxCountDistinctPartitionFunction( - expr: Expression, - base: AggregateExpression1, - relativeSD: Double) - extends AggregateFunction1 { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - hyperLogLog.offer(evaluatedExpr) - } - } - - override def eval(input: InternalRow): Any = hyperLogLog -} - -case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) - extends UnaryExpression with AggregateExpression1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance(): ApproxCountDistinctMergeFunction = { - new ApproxCountDistinctMergeFunction(child, this, relativeSD) - } -} - -case class ApproxCountDistinctMergeFunction( - expr: Expression, - base: AggregateExpression1, - relativeSD: Double) - extends AggregateFunction1 { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) - } - - override def eval(input: InternalRow): Any = hyperLogLog.cardinality() -} - -case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) - extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - - override def asPartial: SplitEvaluation = { - val partialCount = - Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")() - - SplitEvaluation( - ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD), - partialCount :: Nil) - } - - override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) -} - -case class Average(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def prettyName: String = "avg" - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 4 digits after decimal point, like Hive - DecimalType.bounded(precision + 4, scale + 4) - case _ => - DoubleType - } - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(precision, scale) => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - - // partialSum already increase the precision by 10 - val castedSum = Cast(Sum(partialSum.toAttribute), partialSum.dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), partialSum.dataType) - SplitEvaluation( - Cast(Divide(castedSum, castedCount), dataType), - partialCount :: partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - - val castedSum = Cast(Sum(partialSum.toAttribute), dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), dataType) - SplitEvaluation( - Divide(castedSum, castedCount), - partialCount :: partialSum :: Nil) - } - } - - override def newInstance(): AverageFunction = new AverageFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function average") -} - -case class AverageFunction(expr: Expression, base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - private val calcType = - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private var count: Long = _ - private val sum = MutableLiteral(zero.eval(null), calcType) - - private def addFunction(value: Any) = Add(sum, - Cast(Literal.create(value, expr.dataType), calcType)) - - override def eval(input: InternalRow): Any = { - if (count == 0L) { - null - } else { - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - val dt = DecimalType.bounded(precision + 14, scale + 4) - Cast(Divide(Cast(sum, dt), Cast(Literal(count), dt)), dataType).eval(null) - case _ => - Divide( - Cast(sum, dataType), - Cast(Literal(count), dataType)).eval(null) - } - } - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1 - sum.update(addFunction(evaluatedExpr), input) - } - } -} - -case class Sum(child: Expression) extends UnaryExpression with PartialAggregate1 { - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 10 digits left of decimal point, like Hive - DecimalType.bounded(precision + 10, scale) - case _ => - child.dataType - } - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(_, _) => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - Cast(Sum(partialSum.toAttribute), dataType), - partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - Sum(partialSum.toAttribute), - partialSum :: Nil) - } - } - - override def newInstance(): SumFunction = new SumFunction(child, this) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sum") -} - -case class SumFunction(expr: Expression, base: AggregateExpression1) extends AggregateFunction1 { - def this() = this(null, null) // Required for serialization. - - private val calcType = - expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType.bounded(precision + 10, scale) - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private val sum = MutableLiteral(null, calcType) - - private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum)) - - override def update(input: InternalRow): Unit = { - sum.update(addFunction, input) - } - - override def eval(input: InternalRow): Any = { - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(sum, dataType).eval(null) - case _ => sum.eval(null) - } - } -} - -case class SumDistinct(child: Expression) extends UnaryExpression with PartialAggregate1 { - - def this() = this(null) - override def nullable: Boolean = true - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - // Add 10 digits left of decimal point, like Hive - DecimalType.bounded(precision + 10, scale) - case _ => - child.dataType - } - override def toString: String = s"sum(distinct $child)" - override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() - SplitEvaluation( - CombineSetsAndSum(partialSet.toAttribute, this), - partialSet :: Nil) - } - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") -} - -case class SumDistinctFunction(expr: Expression, base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - private val seen = new scala.collection.mutable.HashSet[Any]() - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - seen += evaluatedExpr - } - } - - override def eval(input: InternalRow): Any = { - if (seen.size == 0) { - null - } else { - Cast(Literal( - seen.reduceLeft( - dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - dataType).eval(null) - } - } -} - -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression1 { - def this() = this(null, null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = true - override def dataType: DataType = base.dataType - override def toString: String = s"CombineAndSum($inputSet)" - override def newInstance(): CombineSetsAndSumFunction = { - new CombineSetsAndSumFunction(inputSet, this) - } -} - -case class CombineSetsAndSumFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - override def update(input: InternalRow): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next()) - } - } - - override def eval(input: InternalRow): Any = { - val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] - if (casted.size == 0) { - null - } else { - Cast(Literal( - casted.iterator.map(f => f.get(0, null)).reduceLeft( - base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - base.dataType).eval(null) - } - } -} - -case class First( - child: Expression, - ignoreNullsExpr: Expression) - extends UnaryExpression with PartialAggregate1 { - - def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - - private val ignoreNulls: Boolean = ignoreNullsExpr match { - case Literal(b: Boolean, BooleanType) => b - case _ => - throw new AnalysisException("The second argument of First should be a boolean literal.") - } - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"first(${child}${if (ignoreNulls) " ignore nulls"})" - - override def asPartial: SplitEvaluation = { - val partialFirst = Alias(First(child, ignoreNulls), "PartialFirst")() - SplitEvaluation( - First(partialFirst.toAttribute, ignoreNulls), - partialFirst :: Nil) - } - override def newInstance(): FirstFunction = new FirstFunction(child, ignoreNulls, this) -} - -object First { - def apply(child: Expression): First = First(child, ignoreNulls = false) - - def apply(child: Expression, ignoreNulls: Boolean): First = - First(child, Literal.create(ignoreNulls, BooleanType)) -} - -case class FirstFunction( - expr: Expression, - ignoreNulls: Boolean, - base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization. - - private[this] var result: Any = null - - private[this] var valueSet: Boolean = false - - override def update(input: InternalRow): Unit = { - if (!valueSet) { - val value = expr.eval(input) - // When we have not set the result, we will set the result if we respect nulls - // (i.e. ignoreNulls is false), or we ignore nulls and the evaluated value is not null. - if (!ignoreNulls || (ignoreNulls && value != null)) { - result = value - valueSet = true - } - } - } - - override def eval(input: InternalRow): Any = result -} - -case class Last( - child: Expression, - ignoreNullsExpr: Expression) - extends UnaryExpression with PartialAggregate1 { - - def this(child: Expression) = this(child, Literal.create(false, BooleanType)) - - private val ignoreNulls: Boolean = ignoreNullsExpr match { - case Literal(b: Boolean, BooleanType) => b - case _ => - throw new AnalysisException("The second argument of First should be a boolean literal.") - } - - override def references: AttributeSet = child.references - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"last($child)${if (ignoreNulls) " ignore nulls"}" - - override def asPartial: SplitEvaluation = { - val partialLast = Alias(Last(child, ignoreNulls), "PartialLast")() - SplitEvaluation( - Last(partialLast.toAttribute, ignoreNulls), - partialLast :: Nil) - } - override def newInstance(): LastFunction = new LastFunction(child, ignoreNulls, this) -} - -object Last { - def apply(child: Expression): Last = Last(child, ignoreNulls = false) - - def apply(child: Expression, ignoreNulls: Boolean): Last = - Last(child, Literal.create(ignoreNulls, BooleanType)) -} - -case class LastFunction( - expr: Expression, - ignoreNulls: Boolean, - base: AggregateExpression1) - extends AggregateFunction1 { - - def this() = this(null, null.asInstanceOf[Boolean], null) // Required for serialization. - - var result: Any = null - - override def update(input: InternalRow): Unit = { - val value = expr.eval(input) - if (!ignoreNulls || (ignoreNulls && value != null)) { - result = value - } - } - - override def eval(input: InternalRow): Any = { - result - } -} - -/** - * Calculate Pearson Correlation Coefficient for the given columns. - * Only support AggregateExpression2. - * - */ -case class Corr(left: Expression, right: Expression) - extends BinaryExpression with AggregateExpression1 with ImplicitCastInputTypes { - override def nullable: Boolean = false - override def dataType: DoubleType.type = DoubleType - override def toString: String = s"corr($left, $right)" - override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType) - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException( - "Corr only supports the new AggregateExpression2 and can only be used " + - "when spark.sql.useAggregate2 = true") - } -} - -// Compute standard deviation based on online algorithm specified here: -// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance -abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 { - override def nullable: Boolean = true - override def dataType: DataType = DoubleType - - def isSample: Boolean - - override def asPartial: SplitEvaluation = { - val partialStd = Alias(ComputePartialStd(child), "PartialStddev")() - SplitEvaluation(MergePartialStd(partialStd.toAttribute, isSample), partialStd :: Nil) - } - - override def newInstance(): StddevFunction = new StddevFunction(child, this, isSample) - - override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForNumericExpr(child.dataType, "function stddev") - -} - -// Compute the population standard deviation of a column -case class StddevPop(child: Expression) extends StddevAgg1(child) { - - override def toString: String = s"stddev_pop($child)" - override def isSample: Boolean = false -} - -// Compute the sample standard deviation of a column -case class StddevSamp(child: Expression) extends StddevAgg1(child) { - - override def toString: String = s"stddev_samp($child)" - override def isSample: Boolean = true -} - -case class ComputePartialStd(child: Expression) extends UnaryExpression with AggregateExpression1 { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = false - override def dataType: DataType = ArrayType(DoubleType) - override def toString: String = s"computePartialStddev($child)" - override def newInstance(): ComputePartialStdFunction = - new ComputePartialStdFunction(child, this) -} - -case class ComputePartialStdFunction ( - expr: Expression, - base: AggregateExpression1 - ) extends AggregateFunction1 { - - def this() = this(null, null) // Required for serialization - - private val computeType = DoubleType - private val zero = Cast(Literal(0), computeType) - private var partialCount: Long = 0L - - // the mean of data processed so far - private val partialAvg: MutableLiteral = MutableLiteral(zero.eval(null), computeType) - - // update average based on this formula: - // avg = avg + (value - avg)/count - private def avgAddFunction (value: Literal): Expression = { - val delta = Subtract(Cast(value, computeType), partialAvg) - Add(partialAvg, Divide(delta, Cast(Literal(partialCount), computeType))) - } - - // the sum of squares of difference from mean - private val partialMk: MutableLiteral = MutableLiteral(zero.eval(null), computeType) - - // update sum of square of difference from mean based on following formula: - // Mk = Mk + (value - preAvg) * (value - updatedAvg) - private def mkAddFunction(value: Literal, prePartialAvg: MutableLiteral): Expression = { - val delta1 = Subtract(Cast(value, computeType), prePartialAvg) - val delta2 = Subtract(Cast(value, computeType), partialAvg) - Add(partialMk, Multiply(delta1, delta2)) - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - val exprValue = Literal.create(evaluatedExpr, expr.dataType) - val prePartialAvg = partialAvg.copy() - partialCount += 1 - partialAvg.update(avgAddFunction(exprValue), input) - partialMk.update(mkAddFunction(exprValue, prePartialAvg), input) - } - } - - override def eval(input: InternalRow): Any = { - new GenericArrayData(Array(Cast(Literal(partialCount), computeType).eval(null), - partialAvg.eval(null), - partialMk.eval(null))) - } -} - -case class MergePartialStd( - child: Expression, - isSample: Boolean -) extends UnaryExpression with AggregateExpression1 { - def this() = this(null, false) // required for serialization - - override def children: Seq[Expression] = child:: Nil - override def nullable: Boolean = false - override def dataType: DataType = DoubleType - override def toString: String = s"MergePartialStd($child)" - override def newInstance(): MergePartialStdFunction = { - new MergePartialStdFunction(child, this, isSample) - } -} - -case class MergePartialStdFunction( - expr: Expression, - base: AggregateExpression1, - isSample: Boolean -) extends AggregateFunction1 { - def this() = this (null, null, false) // Required for serialization - - private val computeType = DoubleType - private val zero = Cast(Literal(0), computeType) - private val combineCount = MutableLiteral(zero.eval(null), computeType) - private val combineAvg = MutableLiteral(zero.eval(null), computeType) - private val combineMk = MutableLiteral(zero.eval(null), computeType) - - private def avgUpdateFunction(preCount: Expression, - partialCount: Expression, - partialAvg: Expression): Expression = { - Divide(Add(Multiply(combineAvg, preCount), - Multiply(partialAvg, partialCount)), - Add(preCount, partialCount)) - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input).asInstanceOf[ArrayData] - - if (evaluatedExpr != null) { - val exprValue = evaluatedExpr.toArray(computeType) - val (partialCount, partialAvg, partialMk) = - (Literal.create(exprValue(0), computeType), - Literal.create(exprValue(1), computeType), - Literal.create(exprValue(2), computeType)) - - if (Cast(partialCount, LongType).eval(null).asInstanceOf[Long] > 0) { - val preCount = combineCount.copy() - combineCount.update(Add(combineCount, partialCount), input) - - val preAvg = combineAvg.copy() - val avgDelta = Subtract(partialAvg, preAvg) - val mkDelta = Multiply(Multiply(avgDelta, avgDelta), - Divide(Multiply(preCount, partialCount), - combineCount)) - - // update average based on following formula - // (combineAvg * preCount + partialAvg * partialCount) / (preCount + partialCount) - combineAvg.update(avgUpdateFunction(preCount, partialCount, partialAvg), input) - - // update sum of square differences from mean based on following formula - // (combineMk + partialMk + (avgDelta * avgDelta) * (preCount * partialCount/combineCount) - combineMk.update(Add(combineMk, Add(partialMk, mkDelta)), input) - } - } - } - - override def eval(input: InternalRow): Any = { - val count: Long = Cast(combineCount, LongType).eval(null).asInstanceOf[Long] - - if (count == 0) null - else if (count < 2) zero.eval(null) - else { - // when total count > 2 - // stddev_samp = sqrt (combineMk/(combineCount -1)) - // stddev_pop = sqrt (combineMk/combineCount) - val varCol = { - if (isSample) { - Divide(combineMk, Cast(Literal(count - 1), computeType)) - } - else { - Divide(combineMk, Cast(Literal(count), computeType)) - } - } - Sqrt(varCol).eval(null) - } - } -} - -case class StddevFunction( - expr: Expression, - base: AggregateExpression1, - isSample: Boolean -) extends AggregateFunction1 { - - def this() = this(null, null, false) // Required for serialization - - private val computeType = DoubleType - private var curCount: Long = 0L - private val zero = Cast(Literal(0), computeType) - private val curAvg = MutableLiteral(zero.eval(null), computeType) - private val curMk = MutableLiteral(zero.eval(null), computeType) - - private def curAvgAddFunction(value: Literal): Expression = { - val delta = Subtract(Cast(value, computeType), curAvg) - Add(curAvg, Divide(delta, Cast(Literal(curCount), computeType))) - } - private def curMkAddFunction(value: Literal, preAvg: MutableLiteral): Expression = { - val delta1 = Subtract(Cast(value, computeType), preAvg) - val delta2 = Subtract(Cast(value, computeType), curAvg) - Add(curMk, Multiply(delta1, delta2)) - } - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - val preAvg: MutableLiteral = curAvg.copy() - val exprValue = Literal.create(evaluatedExpr, expr.dataType) - curCount += 1L - curAvg.update(curAvgAddFunction(exprValue), input) - curMk.update(curMkAddFunction(exprValue, preAvg), input) - } - } - - override def eval(input: InternalRow): Any = { - if (curCount == 0) null - else if (curCount < 2) zero.eval(null) - else { - // when total count > 2, - // stddev_samp = sqrt(curMk/(curCount - 1)) - // stddev_pop = sqrt(curMk/curCount) - val varCol = { - if (isSample) { - Divide(curMk, Cast(Literal(curCount - 1), computeType)) - } - else { - Divide(curMk, Cast(Literal(curCount), computeType)) - } - } - Sqrt(varCol).eval(null) - } - } -} - -// placeholder -case class Kurtosis(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "kurtosis" -} - -// placeholder -case class Skewness(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "skewness" -} - -// placeholder -case class VariancePop(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "var_pop" -} - -// placeholder -case class VarianceSamp(child: Expression) extends UnaryExpression with AggregateExpression1 { - - override def newInstance(): AggregateFunction1 = { - throw new UnsupportedOperationException("AggregateExpression1 is no longer supported, " + - "please set spark.sql.useAggregate2 = true") - } - - override def nullable: Boolean = false - - override def dataType: DoubleType.type = DoubleType - - override def foldable: Boolean = false - - override def prettyName: String = "var_samp" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d222dfa33ad8a..f4dba67f13b54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.FullOuter import org.apache.spark.sql.catalyst.plans.LeftOuter @@ -201,8 +202,8 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { object ColumnPruning extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case a @ Aggregate(_, _, e @ Expand(_, _, child)) - if (child.outputSet -- AttributeSet(e.output) -- a.references).nonEmpty => - a.copy(child = e.copy(child = prunedChild(child, AttributeSet(e.output) ++ a.references))) + if (child.outputSet -- e.references -- a.references).nonEmpty => + a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references))) // Eliminate attributes that are not needed to calculate the specified aggregates. case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => @@ -363,7 +364,8 @@ object LikeSimplification extends Rule[LogicalPlan] { object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType) + case e @ AggregateExpression(Count(Literal(null, _)), _, _) => + Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType) @@ -375,7 +377,9 @@ object NullPropagation extends Rule[LogicalPlan] { Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case e @ Count(expr) if !expr.nullable => Count(Literal(1)) + case e @ AggregateExpression(Count(expr), mode, false) if !expr.nullable => + // This rule should be only triggered when isDistinct field is false. + AggregateExpression(Count(Literal(1)), mode, isDistinct = false) // For Coalesce, remove null literals. case e @ Coalesce(children) => @@ -857,12 +861,15 @@ object DecimalAggregates extends Rule[LogicalPlan] { private val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale) + case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale) - case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + if prec + 4 <= MAX_DOUBLE_DIGITS => + val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct) Cast( - Divide(Average(UnscaledValue(e)), Literal.create(math.pow(10.0, scale), DoubleType)), + Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 3b975b904a332..6f4f11406d7c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -84,80 +84,6 @@ object PhysicalOperation extends PredicateHelper { } } -/** - * Matches a logical aggregation that can be performed on distributed data in two steps. The first - * operates on the data in each partition performing partial aggregation for each group. The second - * occurs after the shuffle and completes the aggregation. - * - * This pattern will only match if all aggregate expressions can be computed partially and will - * return the rewritten aggregation expressions for both phases. - * - * The returned values for this match are as follows: - * - Grouping attributes for the final aggregation. - * - Aggregates for the final aggregation. - * - Grouping expressions for the partial aggregation. - * - Partial aggregate expressions. - * - Input to the aggregation. - */ -object PartialAggregation { - type ReturnType = - (Seq[Attribute], Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan) - - def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => - // Collect all aggregate expressions. - val allAggregates = - aggregateExpressions.flatMap(_ collect { case a: AggregateExpression1 => a}) - // Collect all aggregate expressions that can be computed partially. - val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate1 => p}) - - // Only do partial aggregation if supported by all aggregate expressions. - if (allAggregates.size == partialAggregates.size) { - // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] = - partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap - - // We need to pass all grouping expressions though so the grouping can happen a second - // time. However some of them might be unnamed so we alias them allowing them to be - // referenced in the second aggregation. - val namedGroupingExpressions: Seq[(Expression, NamedExpression)] = - groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - } - - // Replace aggregations with a new expression that computes the result from the already - // computed partial evaluations and grouping values. - val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformDown { - case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => - partialEvaluations(new TreeNodeRef(e)).finalEvaluation - - case e: Expression => - namedGroupingExpressions.collectFirst { - case (expr, ne) if expr semanticEquals e => ne.toAttribute - }.getOrElse(e) - }).asInstanceOf[Seq[NamedExpression]] - - val partialComputation = namedGroupingExpressions.map(_._2) ++ - partialEvaluations.values.flatMap(_.partialEvaluations) - - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - - Some( - (namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child)) - } else { - None - } - case _ => None - } -} - - /** * A pattern that finds joins with equality conditions that can be evaluated using equi-join. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0ec9f08571082..b9db7838db08a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -137,13 +137,17 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy /** Returns all of the expressions present in this query plan operator. */ def expressions: Seq[Expression] = { + // Recursively find all expressions from a traversable. + def seqToExpressions(seq: Traversable[Any]): Traversable[Expression] = seq.flatMap { + case e: Expression => e :: Nil + case s: Traversable[_] => seqToExpressions(s) + case other => Nil + } + productIterator.flatMap { case e: Expression => e :: Nil case Some(e: Expression) => e :: Nil - case seq: Traversable[_] => seq.flatMap { - case e: Expression => e :: Nil - case other => Nil - } + case seq: Traversable[_] => seqToExpressions(seq) case other => Nil }.toSeq } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d771088d69dea..764f8aaebddf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.Utils +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -219,8 +219,6 @@ case class Aggregate( !expressions.exists(!_.resolved) && childrenResolved && !hasWindowExpressions } - lazy val newAggregation: Option[Aggregate] = Utils.tryConvert(this) - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index fbdd3a7776f50..5a2368e329976 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -171,16 +171,18 @@ class AnalysisErrorSuite extends AnalysisTest { test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) + // Since we manually construct the logical plan at here and Sum only accetp + // LongType, DoubleType, and DecimalType. We use LongType as the type of a. val plan = Aggregate( Nil, - Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, + Alias(sum(AttributeReference("a", LongType)(exprId = ExprId(1))), "b")() :: Nil, LocalRelation( - AttributeReference("a", IntegerType)(exprId = ExprId(2)))) + AttributeReference("a", LongType)(exprId = ExprId(2)))) assert(plan.resolved) - assertAnalysisError(plan, "resolved attribute(s) a#1 missing from a#2" :: Nil) + assertAnalysisError(plan, "resolved attribute(s) a#1L missing from a#2L" :: Nil) } test("error test for self-join") { @@ -196,7 +198,7 @@ class AnalysisErrorSuite extends AnalysisTest { val plan = Aggregate( AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil, - Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, LocalRelation( AttributeReference("a", BinaryType)(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) @@ -207,13 +209,24 @@ class AnalysisErrorSuite extends AnalysisTest { val plan2 = Aggregate( AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil, - Alias(Sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, LocalRelation( AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) assertAnalysisError(plan2, "map type expression a cannot be used in grouping expression" :: Nil) + + val plan3 = + Aggregate( + AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)) :: Nil, + Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil, + LocalRelation( + AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)), + AttributeReference("b", IntegerType)(exprId = ExprId(1)))) + + assertAnalysisError(plan3, + "array type expression a cannot be used in grouping expression" :: Nil) } test("Join can't work on binary and map types") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 71d2939ecffe6..65f09b46afae1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -45,7 +45,7 @@ class AnalysisSuite extends AnalysisTest { val explode = Explode(AttributeReference("a", IntegerType, nullable = true)()) assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved) - assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved) + assert(!Project(Seq(Alias(count(Literal(1)), "count")()), testRelation).resolved) } test("analyze project") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 40c4ae7920918..fed591fd90a9a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index c9bcc68f02030..b902982add8ff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.{TypeCollection, StringType} @@ -140,15 +141,16 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { } test("check types for aggregates") { + // We use AggregateFunction directly at here because the error will be thrown from it + // instead of from AggregateExpression, which is the wrapper of an AggregateFunction. + // We will cast String to Double for sum and average assertSuccess(Sum('stringField)) - assertSuccess(SumDistinct('stringField)) assertSuccess(Average('stringField)) assertError(Min('complexField), "min does not support ordering on type") assertError(Max('complexField), "max does not support ordering on type") assertError(Sum('booleanField), "function sum requires numeric type") - assertError(SumDistinct('booleanField), "function sumDistinct requires numeric type") assertError(Average('booleanField), "function average requires numeric type") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index e67606288f514..8aaefa84937c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -162,7 +162,7 @@ class ConstantFoldingSuite extends PlanTest { testRelation .select( Rand(5L) + Literal(1) as Symbol("c1"), - Sum('a) as Symbol("c2")) + sum('a) as Symbol("c2")) val optimized = Optimize.execute(originalQuery.analyze) @@ -170,7 +170,7 @@ class ConstantFoldingSuite extends PlanTest { testRelation .select( Rand(5L) + Literal(1.0) as Symbol("c1"), - Sum('a) as Symbol("c2")) + sum('a) as Symbol("c2")) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ed810a12808f0..0290fafe879f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -68,7 +68,7 @@ class FilterPushdownSuite extends PlanTest { test("column pruning for group") { val originalQuery = testRelation - .groupBy('a)('a, Count('b)) + .groupBy('a)('a, count('b)) .select('a) val optimized = Optimize.execute(originalQuery.analyze) @@ -84,7 +84,7 @@ class FilterPushdownSuite extends PlanTest { test("column pruning for group with alias") { val originalQuery = testRelation - .groupBy('a)('a as 'c, Count('b)) + .groupBy('a)('a as 'c, count('b)) .select('c) val optimized = Optimize.execute(originalQuery.analyze) @@ -656,7 +656,7 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: push down filter when filter on group by expression") { val originalQuery = testRelation - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .select('a, 'c) .where('a === 2) @@ -664,7 +664,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .where('a === 2) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .analyze comparePlans(optimized, correctAnswer) } @@ -672,7 +672,7 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: don't push down filter when filter not on group by expression") { val originalQuery = testRelation .select('a, 'b) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .where('c === 2L) val optimized = Optimize.execute(originalQuery.analyze) @@ -683,7 +683,7 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: push down filters partially which are subset of group by expressions") { val originalQuery = testRelation .select('a, 'b) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .where('c === 2L && 'a === 3) val optimized = Optimize.execute(originalQuery.analyze) @@ -691,7 +691,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .select('a, 'b) .where('a === 3) - .groupBy('a)('a, Count('b) as 'c) + .groupBy('a)('a, count('b) as 'c) .where('c === 2L) .analyze diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d25807cf8d09c..3b69247dc54ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -1338,7 +1339,7 @@ class DataFrame private[sql]( if (groupColExprIds.contains(attr.exprId)) { attr } else { - Alias(First(attr), attr.name)() + Alias(new First(attr).toAggregateExpression(), attr.name)() } } Aggregate(groupCols, aggCols, logicalPlan) @@ -1381,11 +1382,11 @@ class DataFrame private[sql]( // The list of summary statistics to compute, in the form of expressions. val statistics = List[(String, Expression => Expression)]( - "count" -> Count, - "mean" -> Average, - "stddev" -> StddevSamp, - "min" -> Min, - "max" -> Max) + "count" -> ((child: Expression) => Count(child).toAggregateExpression()), + "mean" -> ((child: Expression) => Average(child).toAggregateExpression()), + "stddev" -> ((child: Expression) => StddevSamp(child).toAggregateExpression()), + "min" -> ((child: Expression) => Min(child).toAggregateExpression()), + "max" -> ((child: Expression) => Max(child).toAggregateExpression())) val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index f9eab5c2e965b..5babf2cc0ca25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -21,8 +21,9 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType @@ -70,7 +71,7 @@ class GroupedData protected[sql]( } } - private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) + private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) : DataFrame = { val columnExprs = if (colNames.isEmpty) { @@ -88,30 +89,28 @@ class GroupedData protected[sql]( namedExpr } } - toDF(columnExprs.map(f)) + toDF(columnExprs.map(expr => f(expr).toAggregateExpression())) } private[this] def strToExpr(expr: String): (Expression => Expression) = { - expr.toLowerCase match { - case "avg" | "average" | "mean" => Average - case "max" => Max - case "min" => Min - case "stddev" | "std" => StddevSamp - case "stddev_pop" => StddevPop - case "stddev_samp" => StddevSamp - case "variance" => VarianceSamp - case "var_pop" => VariancePop - case "var_samp" => VarianceSamp - case "sum" => Sum - case "skewness" => Skewness - case "kurtosis" => Kurtosis - case "count" | "size" => - // Turn count(*) into count(1) - (inputExpr: Expression) => inputExpr match { - case s: Star => Count(Literal(1)) - case _ => Count(inputExpr) - } + val exprToFunc: (Expression => Expression) = { + (inputExpr: Expression) => expr.toLowerCase match { + // We special handle a few cases that have alias that are not in function registry. + case "avg" | "average" | "mean" => + UnresolvedFunction("avg", inputExpr :: Nil, isDistinct = false) + case "stddev" | "std" => + UnresolvedFunction("stddev", inputExpr :: Nil, isDistinct = false) + // Also special handle count because we need to take care count(*). + case "count" | "size" => + // Turn count(*) into count(1) + inputExpr match { + case s: Star => Count(Literal(1)).toAggregateExpression() + case _ => Count(inputExpr).toAggregateExpression() + } + case name => UnresolvedFunction(name, inputExpr :: Nil, isDistinct = false) + } } + (inputExpr: Expression) => exprToFunc(inputExpr) } /** @@ -213,7 +212,7 @@ class GroupedData protected[sql]( * * @since 1.3.0 */ - def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")())) + def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")())) /** * Compute the average value for each numeric columns for each group. This is an alias for `avg`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index ed8b634ad5630..b7314189b5403 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -448,15 +448,24 @@ private[spark] object SQLConf { defaultValue = Some(true), isPublic = false) - val USE_SQL_AGGREGATE2 = booleanConf("spark.sql.useAggregate2", - defaultValue = Some(true), doc = "") - val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles", defaultValue = Some(true), isPublic = false, doc = "When true, we could use `datasource`.`path` as table in SQL query" ) + val SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING = + booleanConf("spark.sql.specializeSingleDistinctAggPlanning", + defaultValue = Some(true), + isPublic = false, + doc = "When true, if a query only has a single distinct column and it has " + + "grouping expressions, we will use our planner rule to handle this distinct " + + "column (other cases are handled by DistinctAggregationRewriter). " + + "When false, we will always use DistinctAggregationRewriter to plan " + + "aggregation queries with DISTINCT keyword. This is an internal flag that is " + + "used to benchmark the performance impact of using DistinctAggregationRewriter to " + + "plan aggregation queries with a single distinct column.") + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" @@ -532,8 +541,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) - private[spark] def useSqlAggregate2: Boolean = getConf(USE_SQL_AGGREGATE2) - private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = @@ -575,6 +582,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) + protected[spark] override def specializeSingleDistinctAggPlanning: Boolean = + getConf(SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala deleted file mode 100644 index 6f3f1bd97ad52..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ /dev/null @@ -1,205 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.HashMap - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.metric.SQLMetrics - -/** - * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each - * group. - * - * @param partial if true then aggregation is done partially on local data without shuffling to - * ensure all values where `groupingExpressions` are equal are present. - * @param groupingExpressions expressions that are evaluated to determine grouping. - * @param aggregateExpressions expressions that are computed for each group. - * @param child the input data source. - */ -case class Aggregate( - partial: Boolean, - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - child: SparkPlan) - extends UnaryNode { - - override private[sql] lazy val metrics = Map( - "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), - "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - - override def requiredChildDistribution: List[Distribution] = { - if (partial) { - UnspecifiedDistribution :: Nil - } else { - if (groupingExpressions == Nil) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } - } - } - - override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) - - /** - * An aggregate that needs to be computed for each row in a group. - * - * @param unbound Unbound version of this aggregate, used for result substitution. - * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer. - * @param resultAttribute An attribute used to refer to the result of this aggregate in the final - * output. - */ - case class ComputedAggregate( - unbound: AggregateExpression1, - aggregate: AggregateExpression1, - resultAttribute: AttributeReference) - - /** A list of aggregates that need to be computed for each group. */ - private[this] val computedAggregates = aggregateExpressions.flatMap { agg => - agg.collect { - case a: AggregateExpression1 => - ComputedAggregate( - a, - BindReferences.bindReference(a, child.output), - AttributeReference(s"aggResult:$a", a.dataType, a.nullable)()) - } - }.toArray - - /** The schema of the result of all aggregate evaluations */ - private[this] val computedSchema = computedAggregates.map(_.resultAttribute) - - /** Creates a new aggregate buffer for a group. */ - private[this] def newAggregateBuffer(): Array[AggregateFunction1] = { - val buffer = new Array[AggregateFunction1](computedAggregates.length) - var i = 0 - while (i < computedAggregates.length) { - buffer(i) = computedAggregates(i).aggregate.newInstance() - i += 1 - } - buffer - } - - /** Named attributes used to substitute grouping attributes into the final result. */ - private[this] val namedGroups = groupingExpressions.map { - case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute - } - - /** - * A map of substitutions that are used to insert the aggregate expressions and grouping - * expression into the final result expression. - */ - private[this] val resultMap = - (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap - - /** - * Substituted version of aggregateExpressions expressions which are used to compute final - * output rows given a group and the result of all aggregate computations. - */ - private[this] val resultExpressions = aggregateExpressions.map { agg => - agg.transform { - case e: Expression if resultMap.contains(e) => resultMap(e) - } - } - - protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { - val numInputRows = longMetric("numInputRows") - val numOutputRows = longMetric("numOutputRows") - if (groupingExpressions.isEmpty) { - child.execute().mapPartitions { iter => - val buffer = newAggregateBuffer() - var currentRow: InternalRow = null - while (iter.hasNext) { - currentRow = iter.next() - numInputRows += 1 - var i = 0 - while (i < buffer.length) { - buffer(i).update(currentRow) - i += 1 - } - } - val resultProjection = new InterpretedProjection(resultExpressions, computedSchema) - val aggregateResults = new GenericMutableRow(computedAggregates.length) - - var i = 0 - while (i < buffer.length) { - aggregateResults(i) = buffer(i).eval(EmptyRow) - i += 1 - } - - numOutputRows += 1 - Iterator(resultProjection(aggregateResults)) - } - } else { - child.execute().mapPartitions { iter => - val hashTable = new HashMap[InternalRow, Array[AggregateFunction1]] - val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) - - var currentRow: InternalRow = null - while (iter.hasNext) { - currentRow = iter.next() - numInputRows += 1 - val currentGroup = groupingProjection(currentRow) - var currentBuffer = hashTable.get(currentGroup) - if (currentBuffer == null) { - currentBuffer = newAggregateBuffer() - hashTable.put(currentGroup.copy(), currentBuffer) - } - - var i = 0 - while (i < currentBuffer.length) { - currentBuffer(i).update(currentRow) - i += 1 - } - } - - new Iterator[InternalRow] { - private[this] val hashTableIter = hashTable.entrySet().iterator() - private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) - private[this] val resultProjection = - new InterpretedMutableProjection( - resultExpressions, computedSchema ++ namedGroups.map(_._2)) - private[this] val joinedRow = new JoinedRow - - override final def hasNext: Boolean = hashTableIter.hasNext - - override final def next(): InternalRow = { - val currentEntry = hashTableIter.next() - val currentGroup = currentEntry.getKey - val currentBuffer = currentEntry.getValue - numOutputRows += 1 - - var i = 0 - while (i < currentBuffer.length) { - // Evaluating an aggregate buffer returns the result. No row is required since we - // already added all rows in the group using update. - aggregateResults(i) = currentBuffer(i).eval(EmptyRow) - i += 1 - } - resultProjection(joinedRow(aggregateResults, currentGroup)) - } - } - } - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index 55e95769d3faa..91530bd63798a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -45,6 +45,9 @@ case class Expand( override def canProcessUnsafeRows: Boolean = true override def canProcessSafeRows: Boolean = true + override def references: AttributeSet = + AttributeSet(projections.flatten.flatMap(_.references)) + private[this] val projection = { if (outputsUnsafeRows) { (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 0f98fe88b2101..a10d1edcc91aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -38,7 +38,6 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies { DataSourceStrategy :: DDLStrategy :: TakeOrderedAndProject :: - HashAggregation :: Aggregation :: LeftSemiJoin :: EquiJoinSelection :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dd3bb33c57287..d65cb1bae7fb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, Utils} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -146,148 +146,104 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object HashAggregation extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - // Aggregations that can be performed in two phases, before and after the shuffle. - case PartialAggregation( - namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child) if !canBeConvertedToNewAggregation(plan) => - execution.Aggregate( - partial = false, - namedGroupingAttributes, - rewrittenAggregateExpressions, - execution.Aggregate( - partial = true, - groupingExpressions, - partialComputation, - planLater(child))) :: Nil - - case _ => Nil - } - - def canBeConvertedToNewAggregation(plan: LogicalPlan): Boolean = plan match { - case a: logical.Aggregate => - if (sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled) { - a.newAggregation.isDefined - } else { - Utils.checkInvalidAggregateFunction2(a) - false - } - case _ => false - } - - def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] = - exprs.flatMap(_.collect { case a: AggregateExpression1 => a }) - } - /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case p: logical.Aggregate if sqlContext.conf.useSqlAggregate2 && - sqlContext.conf.codegenEnabled => - val converted = p.newAggregation - converted match { - case None => Nil // Cannot convert to new aggregation code path. - case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => - // A single aggregate expression might appear multiple times in resultExpressions. - // In order to avoid evaluating an individual aggregate function multiple times, we'll - // build a set of the distinct aggregate expressions and build a function which can - // be used to re-write expressions so that they reference the single copy of the - // aggregate function which actually gets computed. - val aggregateExpressions = resultExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression2 => agg - } - }.distinct - // For those distinct aggregate expressions, we create a map from the - // aggregate function to the corresponding attribute of the function. - val aggregateFunctionToAttribute = aggregateExpressions.map { agg => - val aggregateFunction = agg.aggregateFunction - val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute - (aggregateFunction, agg.isDistinct) -> attribute - }.toMap - - val (functionsWithDistinct, functionsWithoutDistinct) = - aggregateExpressions.partition(_.isDistinct) - if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { - // This is a sanity check. We should not reach here when we have multiple distinct - // column sets (aggregate.NewAggregation will not match). - sys.error( - "Multiple distinct column sets are not supported by the new aggregation" + - "code path.") - } + case logical.Aggregate(groupingExpressions, resultExpressions, child) => + // A single aggregate expression might appear multiple times in resultExpressions. + // In order to avoid evaluating an individual aggregate function multiple times, we'll + // build a set of the distinct aggregate expressions and build a function which can + // be used to re-write expressions so that they reference the single copy of the + // aggregate function which actually gets computed. + val aggregateExpressions = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => agg + } + }.distinct + // For those distinct aggregate expressions, we create a map from the + // aggregate function to the corresponding attribute of the function. + val aggregateFunctionToAttribute = aggregateExpressions.map { agg => + val aggregateFunction = agg.aggregateFunction + val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + (aggregateFunction, agg.isDistinct) -> attribute + }.toMap + + val (functionsWithDistinct, functionsWithoutDistinct) = + aggregateExpressions.partition(_.isDistinct) + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + // This is a sanity check. We should not reach here when we have multiple distinct + // column sets. Our MultipleDistinctRewriter should take care this case. + sys.error("You hit a query analyzer bug. Please report your query to " + + "Spark user mailing list.") + } - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - - // The original `resultExpressions` are a set of expressions which may reference - // aggregate expressions, grouping column values, and constants. When aggregate operator - // emits output rows, we will use `resultExpressions` to generate an output projection - // which takes the grouping columns and final aggregate result buffer as input. - // Thus, we must re-write the result expressions so that their attributes match up with - // the attributes of the final result projection's input row: - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case AggregateExpression2(aggregateFunction, _, isDistinct) => - // The final aggregation buffer's attributes will be `finalAggregationAttributes`, - // so replace each aggregate expression by its corresponding attribute in the set: - aggregateFunctionToAttribute(aggregateFunction, isDistinct) - case expression => - // Since we're using `namedGroupingAttributes` to extract the grouping key - // columns, we need to replace grouping key expressions with their corresponding - // attributes. We do not rely on the equality check at here since attributes may - // differ cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + + // The original `resultExpressions` are a set of expressions which may reference + // aggregate expressions, grouping column values, and constants. When aggregate operator + // emits output rows, we will use `resultExpressions` to generate an output projection + // which takes the grouping columns and final aggregate result buffer as input. + // Thus, we must re-write the result expressions so that their attributes match up with + // the attributes of the final result projection's input row: + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case AggregateExpression(aggregateFunction, _, isDistinct) => + // The final aggregation buffer's attributes will be `finalAggregationAttributes`, + // so replace each aggregate expression by its corresponding attribute in the set: + aggregateFunctionToAttribute(aggregateFunction, isDistinct) + case expression => + // Since we're using `namedGroupingAttributes` to extract the grouping key + // columns, we need to replace grouping key expressions with their corresponding + // attributes. We do not rely on the equality check at here since attributes may + // differ cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + val aggregateOperator = + if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { + if (functionsWithDistinct.nonEmpty) { + sys.error("Distinct columns cannot exist in Aggregate operator containing " + + "aggregate functions which don't support partial aggregation.") + } else { + aggregate.Utils.planAggregateWithoutPartial( + namedGroupingExpressions.map(_._2), + aggregateExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, + planLater(child)) } + } else if (functionsWithDistinct.isEmpty) { + aggregate.Utils.planAggregateWithoutDistinct( + namedGroupingExpressions.map(_._2), + aggregateExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, + planLater(child)) + } else { + aggregate.Utils.planAggregateWithOneDistinct( + namedGroupingExpressions.map(_._2), + functionsWithDistinct, + functionsWithoutDistinct, + aggregateFunctionToAttribute, + rewrittenResultExpressions, + planLater(child)) + } - val aggregateOperator = - if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { - if (functionsWithDistinct.nonEmpty) { - sys.error("Distinct columns cannot exist in Aggregate operator containing " + - "aggregate functions which don't support partial aggregation.") - } else { - aggregate.Utils.planAggregateWithoutPartial( - namedGroupingExpressions.map(_._2), - aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, - planLater(child)) - } - } else if (functionsWithDistinct.isEmpty) { - aggregate.Utils.planAggregateWithoutDistinct( - namedGroupingExpressions.map(_._2), - aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, - planLater(child)) - } else { - aggregate.Utils.planAggregateWithOneDistinct( - namedGroupingExpressions.map(_._2), - functionsWithDistinct, - functionsWithoutDistinct, - aggregateFunctionToAttribute, - rewrittenResultExpressions, - planLater(child)) - } - - aggregateOperator - } + aggregateOperator case _ => Nil } @@ -422,18 +378,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Filter(condition, planLater(child)) :: Nil case e @ logical.Expand(_, _, child) => execution.Expand(e.projections, e.output, planLater(child)) :: Nil - case a @ logical.Aggregate(group, agg, child) => { - val useNewAggregation = sqlContext.conf.useSqlAggregate2 && sqlContext.conf.codegenEnabled - if (useNewAggregation && a.newAggregation.isDefined) { - // If this logical.Aggregate can be planned to use new aggregation code path - // (i.e. it can be planned by the Strategy Aggregation), we will not use the old - // aggregation code path. - Nil - } else { - Utils.checkInvalidAggregateFunction2(a) - execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil - } - } case logical.Window(projectList, windowExprs, partitionSpec, orderSpec, child) => execution.Window( projectList, windowExprs, partitionSpec, orderSpec, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 99fb7a40b72e1..008478a6a0e17 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -35,9 +35,9 @@ import scala.collection.mutable.ArrayBuffer abstract class AggregationIterator( groupingKeyAttributes: Seq[Attribute], valueAttributes: Seq[Attribute], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], @@ -76,14 +76,14 @@ abstract class AggregationIterator( // Initialize all AggregateFunctions by binding references if necessary, // and set inputBufferOffset and mutableBufferOffset. - protected val allAggregateFunctions: Array[AggregateFunction2] = { + protected val allAggregateFunctions: Array[AggregateFunction] = { var mutableBufferOffset = 0 var inputBufferOffset: Int = initialInputBufferOffset - val functions = new Array[AggregateFunction2](allAggregateExpressions.length) + val functions = new Array[AggregateFunction](allAggregateExpressions.length) var i = 0 while (i < allAggregateExpressions.length) { val func = allAggregateExpressions(i).aggregateFunction - val funcWithBoundReferences: AggregateFunction2 = allAggregateExpressions(i).mode match { + val funcWithBoundReferences: AggregateFunction = allAggregateExpressions(i).mode match { case Partial | Complete if func.isInstanceOf[ImperativeAggregate] => // We need to create BoundReferences if the function is not an // expression-based aggregate function (it does not support code-gen) and the mode of @@ -135,7 +135,7 @@ abstract class AggregationIterator( } // All AggregateFunctions functions with mode Partial, PartialMerge, or Final. - private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction2] = + private[this] val nonCompleteAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.take(nonCompleteAggregateExpressions.length) // All imperative aggregate functions with mode Partial, PartialMerge, or Final. @@ -172,7 +172,7 @@ abstract class AggregationIterator( case (Some(Partial), None) => val updateExpressions = nonCompleteAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val expressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() @@ -204,7 +204,7 @@ abstract class AggregationIterator( // allAggregateFunctions.flatMap(_.cloneBufferAttributes) val mergeExpressions = nonCompleteAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } // This projection is used to merge buffer values for all expression-based aggregates. val expressionAggMergeProjection = @@ -225,7 +225,7 @@ abstract class AggregationIterator( // Final-Complete case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = + val completeAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) // All imperative aggregate functions with mode Complete. val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = @@ -248,7 +248,7 @@ abstract class AggregationIterator( val mergeExpressions = nonCompleteAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } ++ completeOffsetExpressions val finalExpressionAggMergeProjection = newMutableProjection(mergeExpressions, mergeInputSchema)() @@ -256,7 +256,7 @@ abstract class AggregationIterator( val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() @@ -282,7 +282,7 @@ abstract class AggregationIterator( // Complete-only case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = + val completeAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) // All imperative aggregate functions with mode Complete. val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = @@ -291,7 +291,7 @@ abstract class AggregationIterator( val updateExpressions = completeAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferSchema ++ valueAttributes)() @@ -353,7 +353,7 @@ abstract class AggregationIterator( allAggregateFunctions.flatMap(_.aggBufferAttributes) val evalExpressions = allAggregateFunctions.map { case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction2 => NoOp + case agg: AggregateFunction => NoOp } val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index 4d37106e007f5..fb7f30c2aec99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.execution.metric.SQLMetrics case class SortBasedAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 64c673064f576..fe5c3195f867b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression2, AggregateFunction2} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.execution.metric.LongSQLMetric /** - * An iterator used to evaluate [[AggregateFunction2]]. It assumes the input rows have been + * An iterator used to evaluate [[AggregateFunction]]. It assumes the input rows have been * sorted by values of [[groupingKeyAttributes]]. */ class SortBasedAggregationIterator( @@ -31,9 +31,9 @@ class SortBasedAggregationIterator( groupingKeyAttributes: Seq[Attribute], valueAttributes: Seq[Attribute], inputIterator: Iterator[InternalRow], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 15616915f7364..1edde1e5a16d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2 +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} @@ -30,9 +30,9 @@ import org.apache.spark.sql.types.StructType case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index ce8d592c368ee..04391443920ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -64,12 +64,12 @@ import org.apache.spark.sql.types.StructType * @param groupingExpressions * expressions for grouping keys * @param nonCompleteAggregateExpressions - * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]], - * [[PartialMerge]], or [[Final]]. + * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Partial]], + * [[PartialMerge]], or [[Final]]. * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions' * outputs when they are stored in the final aggregation buffer. * @param completeAggregateExpressions - * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]]. + * [[AggregateExpression]] containing [[AggregateFunction]]s with mode [[Complete]]. * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs * when they are stored in the final aggregation buffer. * @param resultExpressions @@ -83,9 +83,9 @@ import org.apache.spark.sql.types.StructType */ class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], - nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateExpressions: Seq[AggregateExpression], nonCompleteAggregateAttributes: Seq[Attribute], - completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateExpressions: Seq[AggregateExpression], completeAggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], @@ -106,7 +106,7 @@ class TungstenAggregationIterator( // A Seq containing all AggregateExpressions. // It is important that all AggregateExpressions with the mode Partial, PartialMerge or Final // are at the beginning of the allAggregateExpressions. - private[this] val allAggregateExpressions: Seq[AggregateExpression2] = + private[this] val allAggregateExpressions: Seq[AggregateExpression] = nonCompleteAggregateExpressions ++ completeAggregateExpressions // Check to make sure we do not have more than three modes in our AggregateExpressions. @@ -150,10 +150,10 @@ class TungstenAggregationIterator( // Initialize all AggregateFunctions by binding references, if necessary, // and setting inputBufferOffset and mutableBufferOffset. private def initializeAllAggregateFunctions( - startingInputBufferOffset: Int): Array[AggregateFunction2] = { + startingInputBufferOffset: Int): Array[AggregateFunction] = { var mutableBufferOffset = 0 var inputBufferOffset: Int = startingInputBufferOffset - val functions = new Array[AggregateFunction2](allAggregateExpressions.length) + val functions = new Array[AggregateFunction](allAggregateExpressions.length) var i = 0 while (i < allAggregateExpressions.length) { val func = allAggregateExpressions(i).aggregateFunction @@ -195,7 +195,7 @@ class TungstenAggregationIterator( functions } - private[this] var allAggregateFunctions: Array[AggregateFunction2] = + private[this] var allAggregateFunctions: Array[AggregateFunction] = initializeAllAggregateFunctions(initialInputBufferOffset) // Positions of those imperative aggregate functions in allAggregateFunctions. @@ -263,7 +263,7 @@ class TungstenAggregationIterator( case (Some(Partial), None) => val updateExpressions = allAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val imperativeAggregateFunctions: Array[ImperativeAggregate] = allAggregateFunctions.collect { case func: ImperativeAggregate => func} @@ -286,7 +286,7 @@ class TungstenAggregationIterator( case (Some(PartialMerge), None) | (Some(Final), None) => val mergeExpressions = allAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val imperativeAggregateFunctions: Array[ImperativeAggregate] = allAggregateFunctions.collect { case func: ImperativeAggregate => func} @@ -307,11 +307,11 @@ class TungstenAggregationIterator( // Final-Complete case (Some(Final), Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = + val completeAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - val nonCompleteAggregateFunctions: Array[AggregateFunction2] = + val nonCompleteAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.take(nonCompleteAggregateExpressions.length) val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } @@ -321,7 +321,7 @@ class TungstenAggregationIterator( val mergeExpressions = nonCompleteAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.mergeExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } ++ completeOffsetExpressions val finalMergeProjection = newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() @@ -331,7 +331,7 @@ class TungstenAggregationIterator( Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val completeUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() @@ -358,7 +358,7 @@ class TungstenAggregationIterator( // Complete-only case (None, Some(Complete)) => - val completeAggregateFunctions: Array[AggregateFunction2] = + val completeAggregateFunctions: Array[AggregateFunction] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) // All imperative aggregate functions with mode Complete. val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = @@ -366,7 +366,7 @@ class TungstenAggregationIterator( val updateExpressions = completeAggregateFunctions.flatMap { case ae: DeclarativeAggregate => ae.updateExpressions - case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + case agg: AggregateFunction => Seq.fill(agg.aggBufferAttributes.length)(NoOp) } val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() @@ -414,7 +414,7 @@ class TungstenAggregationIterator( val joinedRow = new JoinedRow() val evalExpressions = allAggregateFunctions.map { case ae: DeclarativeAggregate => ae.evaluateExpression - case agg: AggregateFunction2 => NoOp + case agg: AggregateFunction => NoOp } val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() // These are the attributes of the row produced by `expressionAggEvalProjection` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index d2f56e0fc14a4..20359c1e540e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction2} +import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index eaafd83158a15..79abf2d5929be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -28,8 +28,8 @@ object Utils { def planAggregateWithoutPartial( groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { @@ -54,8 +54,8 @@ object Utils { def planAggregateWithoutDistinct( groupingExpressions: Seq[NamedExpression], - aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], + aggregateExpressions: Seq[AggregateExpression], + aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. @@ -137,9 +137,9 @@ object Utils { def planAggregateWithOneDistinct( groupingExpressions: Seq[NamedExpression], - functionsWithDistinct: Seq[AggregateExpression2], - functionsWithoutDistinct: Seq[AggregateExpression2], - aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], + functionsWithDistinct: Seq[AggregateExpression], + functionsWithoutDistinct: Seq[AggregateExpression], + aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { @@ -253,16 +253,16 @@ object Utils { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. - case agg @ AggregateExpression2(aggregateFunction, mode, true) => + case agg @ AggregateExpression(aggregateFunction, mode, true) => val rewrittenAggregateFunction = aggregateFunction.transformDown { case expr if expr == distinctColumnExpression => distinctColumnAttribute - }.asInstanceOf[AggregateFunction2] + }.asInstanceOf[AggregateFunction] // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, // they still can see distinct aggregations. val rewrittenAggregateExpression = - AggregateExpression2(rewrittenAggregateFunction, Complete, isDistinct = true) + AggregateExpression(rewrittenAggregateFunction, Complete, isDistinct = true) val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) (rewrittenAggregateExpression, aggregateFunctionAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 0b3192a6da9d8..8cc25c2440633 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn} @@ -70,7 +70,7 @@ abstract class Aggregator[-A, B, C] { implicit bEncoder: Encoder[B], cEncoder: Encoder[C]): TypedColumn[A, C] = { val expr = - new AggregateExpression2( + new AggregateExpression( TypedAggregateExpression(this), Complete, false) @@ -78,4 +78,3 @@ abstract class Aggregator[-A, B, C] { new TypedColumn[A, C](expr, encoderFor[C]) } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 8b9247adea200..fc873c04f88f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.types.BooleanType import org.apache.spark.sql.{Column, catalyst} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ /** @@ -141,40 +141,56 @@ class WindowSpec private[sql]( */ private[sql] def withAggregate(aggregate: Column): Column = { val windowExpr = aggregate.expr match { - case Average(child) => WindowExpression( - UnresolvedWindowFunction("avg", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Sum(child) => WindowExpression( - UnresolvedWindowFunction("sum", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(child) => WindowExpression( - UnresolvedWindowFunction("count", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction( - "first_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction( - "last_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Min(child) => WindowExpression( - UnresolvedWindowFunction("min", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Max(child) => WindowExpression( - UnresolvedWindowFunction("max", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case wf: WindowFunction => WindowExpression( - wf, - WindowSpecDefinition(partitionSpec, orderSpec, frame)) + // First, we check if we get an aggregate function without the DISTINCT keyword. + // Right now, we do not support using a DISTINCT aggregate function as a + // window function. + case AggregateExpression(aggregateFunction, _, isDistinct) if !isDistinct => + aggregateFunction match { + case Average(child) => WindowExpression( + UnresolvedWindowFunction("avg", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Sum(child) => WindowExpression( + UnresolvedWindowFunction("sum", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Count(child) => WindowExpression( + UnresolvedWindowFunction("count", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case First(child, ignoreNulls) => WindowExpression( + // TODO this is a hack for Hive UDAF first_value + UnresolvedWindowFunction( + "first_value", + child :: ignoreNulls :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Last(child, ignoreNulls) => WindowExpression( + // TODO this is a hack for Hive UDAF last_value + UnresolvedWindowFunction( + "last_value", + child :: ignoreNulls :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Min(child) => WindowExpression( + UnresolvedWindowFunction("min", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case Max(child) => WindowExpression( + UnresolvedWindowFunction("max", child :: Nil), + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case x => + throw new UnsupportedOperationException(s"$x is not supported in a window operation.") + } + + case AggregateExpression(aggregateFunction, _, isDistinct) if isDistinct => + throw new UnsupportedOperationException( + s"Distinct aggregate function ${aggregateFunction} is not supported " + + s"in window operation.") + + case wf: WindowFunction => + WindowExpression( + wf, + WindowSpecDefinition(partitionSpec, orderSpec, frame)) + case x => - throw new UnsupportedOperationException(s"$x is not supported in window operation.") + throw new UnsupportedOperationException(s"$x is not supported in a window operation.") } + new Column(windowExpr) } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 258afadc76951..11dbf391cff98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2} +import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.{Column, Row} import org.apache.spark.sql.types._ @@ -109,7 +109,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { @scala.annotation.varargs def apply(exprs: Column*): Column = { val aggregateExpression = - AggregateExpression2( + AggregateExpression( ScalaUDAF(exprs.map(_.expr), this), Complete, isDistinct = false) @@ -123,7 +123,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { @scala.annotation.varargs def distinct(exprs: Column*): Column = { val aggregateExpression = - AggregateExpression2( + AggregateExpression( ScalaUDAF(exprs.map(_.expr), this), Complete, isDistinct = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6d56542ee0875..22104e4d48617 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, Encoder} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -76,6 +77,12 @@ object functions extends LegacyFunctions { private def withExpr(expr: Expression): Column = Column(expr) + private def withAggregateFunction( + func: AggregateFunction, + isDistinct: Boolean = false): Column = { + Column(func.toAggregateExpression(isDistinct)) + } + private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) @@ -154,7 +161,9 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column): Column = withExpr { ApproxCountDistinct(e.expr) } + def approxCountDistinct(e: Column): Column = withAggregateFunction { + HyperLogLogPlusPlus(e.expr) + } /** * Aggregate function: returns the approximate number of distinct items in a group. @@ -170,8 +179,8 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def approxCountDistinct(e: Column, rsd: Double): Column = withExpr { - ApproxCountDistinct(e.expr, rsd) + def approxCountDistinct(e: Column, rsd: Double): Column = withAggregateFunction { + HyperLogLogPlusPlus(e.expr, rsd, 0, 0) } /** @@ -190,7 +199,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def avg(e: Column): Column = withExpr { Average(e.expr) } + def avg(e: Column): Column = withAggregateFunction { Average(e.expr) } /** * Aggregate function: returns the average of the values in a group. @@ -226,7 +235,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def corr(column1: Column, column2: Column): Column = withExpr { + def corr(column1: Column, column2: Column): Column = withAggregateFunction { Corr(column1.expr, column2.expr) } @@ -246,7 +255,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def count(e: Column): Column = withExpr { + def count(e: Column): Column = withAggregateFunction { e.expr match { // Turn count(*) into count(1) case s: Star => Count(Literal(1)) @@ -269,8 +278,8 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ @scala.annotation.varargs - def countDistinct(expr: Column, exprs: Column*): Column = withExpr { - CountDistinct((expr +: exprs).map(_.expr)) + def countDistinct(expr: Column, exprs: Column*): Column = { + withAggregateFunction(Count.apply((expr +: exprs).map(_.expr)), isDistinct = true) } /** @@ -289,7 +298,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def first(e: Column): Column = withExpr { First(e.expr) } + def first(e: Column): Column = withAggregateFunction { new First(e.expr) } /** * Aggregate function: returns the first value of a column in a group. @@ -305,7 +314,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def kurtosis(e: Column): Column = withExpr { Kurtosis(e.expr) } + def kurtosis(e: Column): Column = withAggregateFunction { Kurtosis(e.expr) } /** * Aggregate function: returns the last value in a group. @@ -313,7 +322,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def last(e: Column): Column = withExpr { Last(e.expr) } + def last(e: Column): Column = withAggregateFunction { new Last(e.expr) } /** * Aggregate function: returns the last value of the column in a group. @@ -329,7 +338,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def max(e: Column): Column = withExpr { Max(e.expr) } + def max(e: Column): Column = withAggregateFunction { Max(e.expr) } /** * Aggregate function: returns the maximum value of the column in a group. @@ -363,7 +372,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def min(e: Column): Column = withExpr { Min(e.expr) } + def min(e: Column): Column = withAggregateFunction { Min(e.expr) } /** * Aggregate function: returns the minimum value of the column in a group. @@ -379,7 +388,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def skewness(e: Column): Column = withExpr { Skewness(e.expr) } + def skewness(e: Column): Column = withAggregateFunction { Skewness(e.expr) } /** * Aggregate function: alias for [[stddev_samp]]. @@ -387,7 +396,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = withExpr { StddevSamp(e.expr) } + def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } /** * Aggregate function: returns the unbiased sample standard deviation of @@ -396,7 +405,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def stddev_samp(e: Column): Column = withExpr { StddevSamp(e.expr) } + def stddev_samp(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } /** * Aggregate function: returns the population standard deviation of @@ -405,7 +414,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def stddev_pop(e: Column): Column = withExpr { StddevPop(e.expr) } + def stddev_pop(e: Column): Column = withAggregateFunction { StddevPop(e.expr) } /** * Aggregate function: returns the sum of all values in the expression. @@ -413,7 +422,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def sum(e: Column): Column = withExpr { Sum(e.expr) } + def sum(e: Column): Column = withAggregateFunction { Sum(e.expr) } /** * Aggregate function: returns the sum of all values in the given column. @@ -429,7 +438,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.3.0 */ - def sumDistinct(e: Column): Column = withExpr { SumDistinct(e.expr) } + def sumDistinct(e: Column): Column = withAggregateFunction(Sum(e.expr), isDistinct = true) /** * Aggregate function: returns the sum of distinct values in the expression. @@ -445,7 +454,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def variance(e: Column): Column = withExpr { VarianceSamp(e.expr) } + def variance(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } /** * Aggregate function: returns the unbiased variance of the values in a group. @@ -453,7 +462,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def var_samp(e: Column): Column = withExpr { VarianceSamp(e.expr) } + def var_samp(e: Column): Column = withAggregateFunction { VarianceSamp(e.expr) } /** * Aggregate function: returns the population variance of the values in a group. @@ -461,7 +470,7 @@ object functions extends LegacyFunctions { * @group agg_funcs * @since 1.6.0 */ - def var_pop(e: Column): Column = withExpr { VariancePop(e.expr) } + def var_pop(e: Column): Column = withAggregateFunction { VariancePop(e.expr) } ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 3de277a79a52c..441a0c6d0e36e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -237,34 +237,10 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("SPARK-8828 sum should return null if all input values are null") { - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "true") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - } - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") { - checkAnswer( - sql("select sum(a), avg(a) from allNulls"), - Seq(Row(null, null)) - ) - } - } + checkAnswer( + sql("select sum(a), avg(a) from allNulls"), + Seq(Row(null, null)) + ) } private def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { @@ -507,29 +483,22 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } test("literal in agg grouping expressions") { - def literalInAggTest(): Unit = { - checkAnswer( - sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - checkAnswer( - sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), - Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - checkAnswer( - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), - sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) - checkAnswer( - sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), - sql("SELECT 1, 2, sum(b) FROM testData2")) - } + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) + checkAnswer( + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1, 2), Row(2, 2), Row(3, 2))) - literalInAggTest() - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - literalInAggTest() - } + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a, 1 + 2"), + sql("SELECT a, 1, sum(b) FROM testData2 GROUP BY a")) + checkAnswer( + sql("SELECT 1, 2, sum(b) FROM testData2 GROUP BY 1, 2"), + sql("SELECT 1, 2, sum(b) FROM testData2")) } test("aggregates with nulls") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index a229e5814df89..e31c528f3a633 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -21,16 +21,13 @@ import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import scala.beans.{BeanInfo, BeanProperty} -import com.clearspring.analytics.stream.cardinality.HyperLogLog - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} +import org.apache.spark.sql.catalyst.expressions.OpenHashSetUDT import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet @@ -134,16 +131,6 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) } - test("HyperLogLogUDT") { - val hyperLogLogUDT = HyperLogLogUDT - val hyperLogLog = new HyperLogLog(0.4) - (1 to 10).foreach(i => hyperLogLog.offer(Row(i))) - - val actual = hyperLogLogUDT.deserialize(hyperLogLogUDT.serialize(hyperLogLog)) - assert(actual.cardinality() === hyperLogLog.cardinality()) - assert(java.util.Arrays.equals(actual.getBytes, hyperLogLog.getBytes)) - } - test("OpenHashSetUDT") { val openHashSetUDT = new OpenHashSetUDT(IntegerType) val set = new OpenHashSet[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 2076c573b56c1..44634dacbde68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -38,7 +38,7 @@ class PlannerSuite extends SharedSQLContext { private def testPartialAggregationPlan(query: LogicalPlan): Unit = { val planner = sqlContext.planner import planner._ - val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) + val plannedOption = Aggregation(query).headOption val planned = plannedOption.getOrElse( fail(s"Could query play aggregation query $query. Is it an aggregation query?")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index cdd885ba14203..4b4f5c6c45c7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -152,36 +152,6 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { ) } - test("Aggregate metrics") { - withSQLConf( - SQLConf.UNSAFE_ENABLED.key -> "false", - SQLConf.CODEGEN_ENABLED.key -> "false", - SQLConf.TUNGSTEN_ENABLED.key -> "false") { - // Assume the execution plan is - // ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0) - val df = testData2.groupBy().count() // 2 partitions - testSparkPlanMetrics(df, 1, Map( - 2L -> ("Aggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 2L)), - 0L -> ("Aggregate", Map( - "number of input rows" -> 2L, - "number of output rows" -> 1L))) - ) - - // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).count() - testSparkPlanMetrics(df2, 1, Map( - 2L -> ("Aggregate", Map( - "number of input rows" -> 6L, - "number of output rows" -> 4L)), - 0L -> ("Aggregate", Map( - "number of input rows" -> 4L, - "number of output rows" -> 3L))) - ) - } - } - test("SortBasedAggregate metrics") { // Because SortBasedAggregate may skip different rows if the number of partitions is different, // this test should use the deterministic number of partitions. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c5f69657f5293..ba6204633b9ca 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -584,7 +584,6 @@ class HiveContext private[hive]( HiveTableScans, DataSinks, Scripts, - HashAggregation, Aggregation, LeftSemiJoin, EquiJoinSelection, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ab88c1e68fd72..6f8ed413a06cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -38,6 +38,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.{AnalysisException, catalyst} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.{logical, _} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin @@ -1508,9 +1509,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name))) /* Aggregate Functions */ - case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) - case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr)) - case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg)) + case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => + Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true) + case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => + Count(Literal(1)).toAggregateExpression() /* Casts */ case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index ea36c132bb190..6bf2c53440baf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -69,11 +69,7 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import testImplicits._ - var originalUseAggregate2: Boolean = _ - override def beforeAll(): Unit = { - originalUseAggregate2 = sqlContext.conf.useSqlAggregate2 - sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, "true") val data1 = Seq[(Integer, Integer)]( (1, 10), (null, -60), @@ -120,7 +116,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te sqlContext.sql("DROP TABLE IF EXISTS agg1") sqlContext.sql("DROP TABLE IF EXISTS agg2") sqlContext.dropTempTable("emptyTable") - sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString) } test("empty table") { @@ -447,73 +442,80 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te } test("single distinct column set") { - // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. - checkAnswer( - sqlContext.sql( - """ - |SELECT - | min(distinct value1), - | sum(distinct value1), - | avg(value1), - | avg(value2), - | max(distinct value1) - |FROM agg2 - """.stripMargin), - Row(-60, 70.0, 101.0/9.0, 5.6, 100)) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | mydoubleavg(distinct value1), - | avg(value1), - | avg(value2), - | key, - | mydoubleavg(value1 - 1), - | mydoubleavg(distinct value1) * 0.1, - | avg(value1 + value2) - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: - Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: - Row(null, null, 3.0, 3, null, null, null) :: - Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | key, - | mydoubleavg(distinct value1), - | mydoublesum(value2), - | mydoublesum(distinct value1), - | mydoubleavg(distinct value1), - | mydoubleavg(value1) - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: - Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: - Row(3, null, 3.0, null, null, null) :: - Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) - - checkAnswer( - sqlContext.sql( - """ - |SELECT - | count(value1), - | count(*), - | count(1), - | count(DISTINCT value1), - | key - |FROM agg2 - |GROUP BY key - """.stripMargin), - Row(3, 3, 3, 2, 1) :: - Row(3, 4, 4, 2, 2) :: - Row(0, 2, 2, 0, 3) :: - Row(3, 4, 4, 3, null) :: Nil) + Seq(true, false).foreach { specializeSingleDistinctAgg => + val conf = + (SQLConf.SPECIALIZE_SINGLE_DISTINCT_AGG_PLANNING.key, + specializeSingleDistinctAgg.toString) + withSQLConf(conf) { + // DISTINCT is not meaningful with Max and Min, so we just ignore the DISTINCT keyword. + checkAnswer( + sqlContext.sql( + """ + |SELECT + | min(distinct value1), + | sum(distinct value1), + | avg(value1), + | avg(value2), + | max(distinct value1) + |FROM agg2 + """.stripMargin), + Row(-60, 70.0, 101.0/9.0, 5.6, 100)) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | mydoubleavg(distinct value1), + | avg(value1), + | avg(value2), + | key, + | mydoubleavg(value1 - 1), + | mydoubleavg(distinct value1) * 0.1, + | avg(value1 + value2) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(120.0, 70.0/3.0, -10.0/3.0, 1, 67.0/3.0 + 100.0, 12.0, 20.0) :: + Row(100.0, 1.0/3.0, 1.0, 2, -2.0/3.0 + 100.0, 10.0, 2.0) :: + Row(null, null, 3.0, 3, null, null, null) :: + Row(110.0, 10.0, 20.0, null, 109.0, 11.0, 30.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | key, + | mydoubleavg(distinct value1), + | mydoublesum(value2), + | mydoublesum(distinct value1), + | mydoubleavg(distinct value1), + | mydoubleavg(value1) + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(1, 120.0, -10.0, 40.0, 120.0, 70.0/3.0 + 100.0) :: + Row(2, 100.0, 3.0, 0.0, 100.0, 1.0/3.0 + 100.0) :: + Row(3, null, 3.0, null, null, null) :: + Row(null, 110.0, 60.0, 30.0, 110.0, 110.0) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT + | count(value1), + | count(*), + | count(1), + | count(DISTINCT value1), + | key + |FROM agg2 + |GROUP BY key + """.stripMargin), + Row(3, 3, 3, 2, 1) :: + Row(3, 4, 4, 2, 2) :: + Row(0, 2, 2, 0, 3) :: + Row(3, 4, 4, 3, null) :: Nil) + } + } } test("single distinct multiple columns set") { @@ -699,48 +701,6 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te val corr7 = sqlContext.sql("SELECT corr(b, c) FROM covar_tab").collect()(0).getDouble(0) assert(math.abs(corr7 - 0.6633880657639323) < 1e-12) - - withSQLConf(SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - val errorMessage = intercept[SparkException] { - val df = Seq.tabulate(10)(i => (1.0 * i, 2.0 * i, i * -1.0)).toDF("a", "b", "c") - val corr1 = df.repartition(2).groupBy().agg(corr("a", "b")).collect()(0).getDouble(0) - }.getMessage - assert(errorMessage.contains("java.lang.UnsupportedOperationException: " + - "Corr only supports the new AggregateExpression2")) - } - } - - test("test Last implemented based on AggregateExpression1") { - // TODO: Remove this test once we remove AggregateExpression1. - import org.apache.spark.sql.functions._ - val df = Seq((1, 1), (2, 2), (3, 3)).toDF("i", "j").repartition(1) - withSQLConf( - SQLConf.SHUFFLE_PARTITIONS.key -> "1", - SQLConf.USE_SQL_AGGREGATE2.key -> "false") { - - checkAnswer( - df.groupBy("i").agg(last("j")), - df - ) - } - } - - test("error handling") { - withSQLConf("spark.sql.useAggregate2" -> "false") { - val errorMessage = intercept[AnalysisException] { - sqlContext.sql( - """ - |SELECT - | key, - | sum(value + 1.5 * key), - | mydoublesum(value), - | mydoubleavg(value) - |FROM agg1 - |GROUP BY key - """.stripMargin).collect() - }.getMessage - assert(errorMessage.contains("implemented based on the new Aggregate Function interface")) - } } test("no aggregation function (SPARK-11486)") { From 47735cdc2a878cfdbe76316d3ff8314a45dabf54 Mon Sep 17 00:00:00 2001 From: "Oscar D. Lara Yejas" Date: Tue, 10 Nov 2015 11:07:57 -0800 Subject: [PATCH 78/88] [SPARK-10863][SPARKR] Method coltypes() (New version) This is a follow up on PR #8984, as the corresponding branch for such PR was damaged. Author: Oscar D. Lara Yejas Closes #9579 from olarayej/SPARK-10863_NEW14. --- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 6 ++-- R/pkg/R/DataFrame.R | 49 ++++++++++++++++++++++++++++++++ R/pkg/R/generics.R | 4 +++ R/pkg/R/schema.R | 15 +--------- R/pkg/R/types.R | 43 ++++++++++++++++++++++++++++ R/pkg/inst/tests/test_sparkSQL.R | 24 +++++++++++++++- 7 files changed, 124 insertions(+), 18 deletions(-) create mode 100644 R/pkg/R/types.R diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 3d6edb70ec98e..369714f7b99c2 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -34,4 +34,5 @@ Collate: 'serialize.R' 'sparkR.R' 'stats.R' + 'types.R' 'utils.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 56b8ed0bf271b..52fd6c9f76c54 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -23,9 +23,11 @@ export("setJobGroup", exportClasses("DataFrame") exportMethods("arrange", + "as.data.frame", "attach", "cache", "collect", + "coltypes", "columns", "count", "cov", @@ -262,6 +264,4 @@ export("structField", "structType", "structType.jobj", "structType.structField", - "print.structType") - -export("as.data.frame") + "print.structType") \ No newline at end of file diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index e9013aa34a84f..cc868069d1e5a 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2152,3 +2152,52 @@ setMethod("with", newEnv <- assignNewEnv(data) eval(substitute(expr), envir = newEnv, enclos = newEnv) }) + +#' Returns the column types of a DataFrame. +#' +#' @name coltypes +#' @title Get column types of a DataFrame +#' @family dataframe_funcs +#' @param x (DataFrame) +#' @return value (character) A character vector with the column types of the given DataFrame +#' @rdname coltypes +#' @examples \dontrun{ +#' irisDF <- createDataFrame(sqlContext, iris) +#' coltypes(irisDF) +#' } +setMethod("coltypes", + signature(x = "DataFrame"), + function(x) { + # Get the data types of the DataFrame by invoking dtypes() function + types <- sapply(dtypes(x), function(x) {x[[2]]}) + + # Map Spark data types into R's data types using DATA_TYPES environment + rTypes <- sapply(types, USE.NAMES=F, FUN=function(x) { + + # Check for primitive types + type <- PRIMITIVE_TYPES[[x]] + + if (is.null(type)) { + # Check for complex types + for (t in names(COMPLEX_TYPES)) { + if (substring(x, 1, nchar(t)) == t) { + type <- COMPLEX_TYPES[[t]] + break + } + } + + if (is.null(type)) { + stop(paste("Unsupported data type: ", x)) + } + } + type + }) + + # Find which types don't have mapping to R + naIndices <- which(is.na(rTypes)) + + # Assign the original scala data types to the unmatched ones + rTypes[naIndices] <- types[naIndices] + + rTypes + }) \ No newline at end of file diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index efef7d66b522c..89731affeb898 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1047,3 +1047,7 @@ setGeneric("attach") #' @rdname with #' @export setGeneric("with") + +#' @rdname coltypes +#' @export +setGeneric("coltypes", function(x) { standardGeneric("coltypes") }) \ No newline at end of file diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 6f0e9a94e9bfa..c6ddb562270b7 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -115,20 +115,7 @@ structField.jobj <- function(x) { } checkType <- function(type) { - primtiveTypes <- c("byte", - "integer", - "float", - "double", - "numeric", - "character", - "string", - "binary", - "raw", - "logical", - "boolean", - "timestamp", - "date") - if (type %in% primtiveTypes) { + if (!is.null(PRIMITIVE_TYPES[[type]])) { return() } else { # Check complex types diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R new file mode 100644 index 0000000000000..1828c23ab0f6d --- /dev/null +++ b/R/pkg/R/types.R @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# types.R. This file handles the data type mapping between Spark and R + +# The primitive data types, where names(PRIMITIVE_TYPES) are Scala types whereas +# values are equivalent R types. This is stored in an environment to allow for +# more efficient look up (environments use hashmaps). +PRIMITIVE_TYPES <- as.environment(list( + "byte"="integer", + "tinyint"="integer", + "smallint"="integer", + "integer"="integer", + "bigint"="numeric", + "float"="numeric", + "double"="numeric", + "decimal"="numeric", + "string"="character", + "binary"="raw", + "boolean"="logical", + "timestamp"="POSIXct", + "date"="Date")) + +# The complex data types. These do not have any direct mapping to R's types. +COMPLEX_TYPES <- list( + "map"=NA, + "array"=NA, + "struct"=NA) + +# The full list of data types. +DATA_TYPES <- as.environment(c(as.list(PRIMITIVE_TYPES), COMPLEX_TYPES)) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index fbdb9a8f1ef6b..06f52d021cff8 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -1467,8 +1467,9 @@ test_that("SQL error message is returned from JVM", { expect_equal(grepl("Table not found: blah", retError), TRUE) }) +irisDF <- createDataFrame(sqlContext, iris) + test_that("Method as.data.frame as a synonym for collect()", { - irisDF <- createDataFrame(sqlContext, iris) expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), collect(irisDF2)) @@ -1503,6 +1504,27 @@ test_that("with() on a DataFrame", { expect_equal(nrow(sum2), 35) }) +test_that("Method coltypes() to get R's data types of a DataFrame", { + expect_equal(coltypes(irisDF), c(rep("numeric", 4), "character")) + + data <- data.frame(c1=c(1,2,3), + c2=c(T,F,T), + c3=c("2015/01/01 10:00:00", "2015/01/02 10:00:00", "2015/01/03 10:00:00")) + + schema <- structType(structField("c1", "byte"), + structField("c3", "boolean"), + structField("c4", "timestamp")) + + # Test primitive types + DF <- createDataFrame(sqlContext, data, schema) + expect_equal(coltypes(DF), c("integer", "logical", "POSIXct")) + + # Test complex types + x <- createDataFrame(sqlContext, list(list(as.environment( + list("a"="b", "c"="d", "e"="f"))))) + expect_equal(coltypes(x), "map") +}) + unlink(parquetPath) unlink(jsonPath) unlink(jsonPathNa) From dfcfcbcc0448ebc6f02eba6bf0495832a321c87e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 10 Nov 2015 11:14:25 -0800 Subject: [PATCH 79/88] [SPARK-11578][SQL][FOLLOW-UP] complete the user facing api for typed aggregation Currently the user facing api for typed aggregation has some limitations: * the customized typed aggregation must be the first of aggregation list * the customized typed aggregation can only use long as buffer type * the customized typed aggregation can only use flat type as result type This PR tries to remove these limitations. Author: Wenchen Fan Closes #9599 from cloud-fan/agg. --- .../catalyst/encoders/ExpressionEncoder.scala | 6 +++ .../aggregate/TypedAggregateExpression.scala | 50 +++++++++++++----- .../spark/sql/expressions/Aggregator.scala | 5 ++ .../spark/sql/DatasetAggregatorSuite.scala | 52 +++++++++++++++++++ 4 files changed, 99 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index c287aebeeee05..005c0627f56b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -185,6 +185,12 @@ case class ExpressionEncoder[T]( }) } + def shift(delta: Int): ExpressionEncoder[T] = { + copy(constructExpression = constructExpression transform { + case r: BoundReference => r.copy(ordinal = r.ordinal + delta) + }) + } + /** * Returns a copy of this encoder where the expressions used to create an object given an * input row have been modified to pull the object out from a nested struct, instead of the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 24d8122b6222b..0e5bc1f9abf28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.aggregate import scala.language.existentials import org.apache.spark.Logging +import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate -import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StructType, DataType} +import org.apache.spark.sql.types._ object TypedAggregateExpression { def apply[A, B : Encoder, C : Encoder]( @@ -67,8 +67,11 @@ case class TypedAggregateExpression( override def nullable: Boolean = true - // TODO: this assumes flat results... - override def dataType: DataType = cEncoder.schema.head.dataType + override def dataType: DataType = if (cEncoder.flat) { + cEncoder.schema.head.dataType + } else { + cEncoder.schema + } override def deterministic: Boolean = true @@ -93,32 +96,51 @@ case class TypedAggregateExpression( case a: AttributeReference => inputMapping(a) }) - // TODO: this probably only works when we are in the first column. val bAttributes = bEncoder.schema.toAttributes lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes) + private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { + // todo: need a more neat way to assign the value. + var i = 0 + while (i < aggBufferAttributes.length) { + aggBufferSchema(i).dataType match { + case IntegerType => buffer.setInt(mutableAggBufferOffset + i, value.getInt(i)) + case LongType => buffer.setLong(mutableAggBufferOffset + i, value.getLong(i)) + } + i += 1 + } + } + override def initialize(buffer: MutableRow): Unit = { - // TODO: We need to either force Aggregator to have a zero or we need to eliminate the need for - // this in execution. - buffer.setInt(mutableAggBufferOffset, aggregator.zero.asInstanceOf[Int]) + val zero = bEncoder.toRow(aggregator.zero) + updateBuffer(buffer, zero) } override def update(buffer: MutableRow, input: InternalRow): Unit = { val inputA = boundA.fromRow(input) - val currentB = boundB.fromRow(buffer) + val currentB = boundB.shift(mutableAggBufferOffset).fromRow(buffer) val merged = aggregator.reduce(currentB, inputA) val returned = boundB.toRow(merged) - buffer.setInt(mutableAggBufferOffset, returned.getInt(0)) + + updateBuffer(buffer, returned) } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - buffer1.setLong( - mutableAggBufferOffset, - buffer1.getLong(mutableAggBufferOffset) + buffer2.getLong(inputAggBufferOffset)) + val b1 = boundB.shift(mutableAggBufferOffset).fromRow(buffer1) + val b2 = boundB.shift(inputAggBufferOffset).fromRow(buffer2) + val merged = aggregator.merge(b1, b2) + val returned = boundB.toRow(merged) + + updateBuffer(buffer1, returned) } override def eval(buffer: InternalRow): Any = { - buffer.getInt(mutableAggBufferOffset) + val b = boundB.shift(mutableAggBufferOffset).fromRow(buffer) + val result = cEncoder.toRow(aggregator.present(b)) + dataType match { + case _: StructType => result + case _ => result.get(0, dataType) + } } override def toString: String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 8cc25c2440633..3c1c457e06d5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -57,6 +57,11 @@ abstract class Aggregator[-A, B, C] { */ def reduce(b: B, a: A): B + /** + * Merge two intermediate values + */ + def merge(b1: B, b2: B): B + /** * Transform the output of the reduction. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 340470c096b87..206095a519762 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -34,9 +34,41 @@ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializ override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) + override def merge(b1: N, b2: N): N = numeric.plus(b1, b2) + override def present(reduction: N): N = reduction } +object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with Serializable { + override def zero: (Long, Long) = (0, 0) + + override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = { + (countAndSum._1 + 1, countAndSum._2 + input._2) + } + + override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = { + (b1._1 + b2._1, b1._2 + b2._2) + } + + override def present(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1 +} + +object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] + with Serializable { + + override def zero: (Long, Long) = (0, 0) + + override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = { + (countAndSum._1 + 1, countAndSum._2 + input._2) + } + + override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = { + (b1._1 + b2._1, b1._2 + b2._2) + } + + override def present(reduction: (Long, Long)): (Long, Long) = reduction +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -62,4 +94,24 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { count("*")), ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L)) } + + test("typed aggregation: complex case") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + expr("avg(_2)").as[Double], + TypedAverage.toColumn), + ("a", 2.0, 2.0), ("b", 3.0, 3.0)) + } + + test("typed aggregation: complex result type") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + + checkAnswer( + ds.groupBy(_._1).agg( + expr("avg(_2)").as[Double], + ComplexResultAgg.toColumn), + ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L))) + } } From 53600854c270d4c953fe95fbae528740b5cf6603 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 10 Nov 2015 11:21:31 -0800 Subject: [PATCH 80/88] [SPARK-11590][SQL] use native json_tuple in lateral view Author: Wenchen Fan Closes #9562 from cloud-fan/json-tuple. --- .../expressions/jsonExpressions.scala | 23 +++++--------- .../expressions/JsonExpressionsSuite.scala | 30 ++++++++++-------- .../org/apache/spark/sql/DataFrame.scala | 8 +++-- .../org/apache/spark/sql/functions.scala | 12 +++++++ .../apache/spark/sql/JsonFunctionsSuite.scala | 23 ++++++++------ .../org/apache/spark/sql/hive/HiveQl.scala | 4 +++ .../apache/spark/sql/hive/HiveQlSuite.scala | 13 ++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 31 +++++++++++++++++++ 8 files changed, 104 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 8c9853e628d2c..8cd73236a7876 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -314,7 +314,7 @@ case class GetJsonObject(json: Expression, path: Expression) } case class JsonTuple(children: Seq[Expression]) - extends Expression with CodegenFallback { + extends Generator with CodegenFallback { import SharedFactory._ @@ -324,8 +324,8 @@ case class JsonTuple(children: Seq[Expression]) } // if processing fails this shared value will be returned - @transient private lazy val nullRow: InternalRow = - new GenericInternalRow(Array.ofDim[Any](fieldExpressions.length)) + @transient private lazy val nullRow: Seq[InternalRow] = + new GenericInternalRow(Array.ofDim[Any](fieldExpressions.length)) :: Nil // the json body is the first child @transient private lazy val jsonExpr: Expression = children.head @@ -344,15 +344,8 @@ case class JsonTuple(children: Seq[Expression]) // and count the number of foldable fields, we'll use this later to optimize evaluation @transient private lazy val constantFields: Int = foldableFieldNames.count(_ != null) - override lazy val dataType: StructType = { - val fields = fieldExpressions.zipWithIndex.map { - case (_, idx) => StructField( - name = s"c$idx", // mirroring GenericUDTFJSONTuple.initialize - dataType = StringType, - nullable = true) - } - - StructType(fields) + override def elementTypes: Seq[(DataType, Boolean, String)] = fieldExpressions.zipWithIndex.map { + case (_, idx) => (StringType, true, s"c$idx") } override def prettyName: String = "json_tuple" @@ -367,7 +360,7 @@ case class JsonTuple(children: Seq[Expression]) } } - override def eval(input: InternalRow): InternalRow = { + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { val json = jsonExpr.eval(input).asInstanceOf[UTF8String] if (json == null) { return nullRow @@ -383,7 +376,7 @@ case class JsonTuple(children: Seq[Expression]) } } - private def parseRow(parser: JsonParser, input: InternalRow): InternalRow = { + private def parseRow(parser: JsonParser, input: InternalRow): Seq[InternalRow] = { // only objects are supported if (parser.nextToken() != JsonToken.START_OBJECT) { return nullRow @@ -433,7 +426,7 @@ case class JsonTuple(children: Seq[Expression]) parser.skipChildren() } - new GenericInternalRow(row) + new GenericInternalRow(row) :: Nil } private def copyCurrentStructure(generator: JsonGenerator, parser: JsonParser): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index f33125f463e14..7b754091f4714 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -209,8 +209,12 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { Literal("f5") :: Nil + private def checkJsonTuple(jt: JsonTuple, expected: InternalRow): Unit = { + assert(jt.eval(null).toSeq.head === expected) + } + test("json_tuple - hive key 1") { - checkEvaluation( + checkJsonTuple( JsonTuple( Literal("""{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") :: jsonTupleQuery), @@ -218,7 +222,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 2") { - checkEvaluation( + checkJsonTuple( JsonTuple( Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: jsonTupleQuery), @@ -226,7 +230,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 2 (mix of foldable fields)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("""{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") :: Literal("f1") :: NonFoldableLiteral("f2") :: @@ -238,7 +242,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 3") { - checkEvaluation( + checkJsonTuple( JsonTuple( Literal("""{"f1": "value13", "f4": "value44", "f3": "value33", "f2": 2, "f5": 5.01}""") :: jsonTupleQuery), @@ -247,7 +251,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 3 (nonfoldable json)") { - checkEvaluation( + checkJsonTuple( JsonTuple( NonFoldableLiteral( """{"f1": "value13", "f4": "value44", @@ -258,7 +262,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 3 (nonfoldable fields)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal( """{"f1": "value13", "f4": "value44", | "f3": "value33", "f2": 2, "f5": 5.01}""".stripMargin) :: @@ -273,43 +277,43 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("json_tuple - hive key 4 - null json") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal(null) :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - hive key 5 - null and empty fields") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("""{"f1": "", "f5": null}""") :: jsonTupleQuery), InternalRow.fromSeq(Seq(UTF8String.fromString(""), null, null, null, null))) } test("json_tuple - hive key 6 - invalid json (array)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("[invalid JSON string]") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - invalid json (object start only)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("{") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - invalid json (no object end)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("""{"foo": "bar"""") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - invalid json (invalid json)") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("\\") :: jsonTupleQuery), InternalRow.fromSeq(Seq(null, null, null, null, null))) } test("json_tuple - preserve newlines") { - checkEvaluation( + checkJsonTuple( JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil), InternalRow.fromSeq(Seq(UTF8String.fromString("b\nc")))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 3b69247dc54ef..9368435a63c35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -750,10 +750,14 @@ class DataFrame private[sql]( // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to // make it a NamedExpression. case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) + case Column(expr: NamedExpression) => expr - // Leave an unaliased explode with an empty list of names since the analyzer will generate the - // correct defaults after the nested expression's type has been resolved. + + // Leave an unaliased generator with an empty list of names since the analyzer will generate + // the correct defaults after the nested expression's type has been resolved. case Column(explode: Explode) => MultiAlias(explode, Nil) + case Column(jt: JsonTuple) => MultiAlias(jt, Nil) + case Column(expr: Expression) => Alias(expr, expr.prettyString)() } Project(namedExpressions.toSeq, logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 22104e4d48617..a59d738010f7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2307,6 +2307,18 @@ object functions extends LegacyFunctions { */ def explode(e: Column): Column = withExpr { Explode(e.expr) } + /** + * Creates a new row for a json column according to the given field names. + * + * @group collection_funcs + * @since 1.6.0 + */ + @scala.annotation.varargs + def json_tuple(json: Column, fields: String*): Column = withExpr { + require(fields.length > 0, "at least 1 field name should be given.") + JsonTuple(json.expr +: fields.map(Literal.apply)) + } + /** * Returns length of array or map. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index e3531d0d6d799..14fd56fc8c222 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -41,23 +41,26 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { test("json_tuple select") { val df: DataFrame = tuples.toDF("key", "jstring") - val expected = Row("1", Row("value1", "value2", "3", null, "5.23")) :: - Row("2", Row("value12", "2", "value3", "4.01", null)) :: - Row("3", Row("value13", "2", "value33", "value44", "5.01")) :: - Row("4", Row(null, null, null, null, null)) :: - Row("5", Row("", null, null, null, null)) :: - Row("6", Row(null, null, null, null, null)) :: + val expected = + Row("1", "value1", "value2", "3", null, "5.23") :: + Row("2", "value12", "2", "value3", "4.01", null) :: + Row("3", "value13", "2", "value33", "value44", "5.01") :: + Row("4", null, null, null, null, null) :: + Row("5", "", null, null, null, null) :: + Row("6", null, null, null, null, null) :: Nil - checkAnswer(df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')"), expected) + checkAnswer( + df.select($"key", functions.json_tuple($"jstring", "f1", "f2", "f3", "f4", "f5")), + expected) } test("json_tuple filter and group") { val df: DataFrame = tuples.toDF("key", "jstring") val expr = df - .selectExpr("json_tuple(jstring, 'f1', 'f2') as jt") - .where($"jt.c0".isNotNull) - .groupBy($"jt.c1") + .select(functions.json_tuple($"jstring", "f1", "f2")) + .where($"c0".isNotNull) + .groupBy($"c1") .count() val expected = Row(null, 1) :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 6f8ed413a06cd..091caab921fe9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1821,6 +1821,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } val explode = "(?i)explode".r + val jsonTuple = "(?i)json_tuple".r def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = { val function = nodes.head @@ -1833,6 +1834,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => (Explode(nodeToExpr(child)), attributes) + case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) => + (JsonTuple(children.map(nodeToExpr)), attributes) + case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 528a7398b10df..a330362b4e1d1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.spark.sql.catalyst.expressions.JsonTuple +import org.apache.spark.sql.catalyst.plans.logical.Generate import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite @@ -183,4 +185,15 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { assertError("select interval '.1111111111' second", "nanosecond 1111111111 outside range") } + + test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") { + val plan = HiveQl.parseSql( + """ + |SELECT * + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b + """.stripMargin) + + assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 9a425d7f6b265..3427152b2da02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1448,4 +1448,35 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { Row("1", "10") :: Row("2", "20") :: Row("3", "30") :: Row("4", "40") :: Nil) } } + + test("SPARK-11590: use native json_tuple in lateral view") { + checkAnswer(sql( + """ + |SELECT a, b + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b + """.stripMargin), Row("value1", "12")) + + // we should use `c0`, `c1`... as the name of fields if no alias is provided, to follow hive. + checkAnswer(sql( + """ + |SELECT c0, c1 + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt + """.stripMargin), Row("value1", "12")) + + // we can also use `json_tuple` in project list. + checkAnswer(sql( + """ + |SELECT json_tuple(json, 'f1', 'f2') + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + """.stripMargin), Row("value1", "12")) + + // we can also mix `json_tuple` with other project expressions. + checkAnswer(sql( + """ + |SELECT json_tuple(json, 'f1', 'f2'), 3.14, str + |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test + """.stripMargin), Row("value1", "12", 3.14, "hello")) + } } From 87aedc48c01dffbd880e6ca84076ed47c68f88d0 Mon Sep 17 00:00:00 2001 From: Nong Li Date: Tue, 10 Nov 2015 11:28:53 -0800 Subject: [PATCH 81/88] [SPARK-10371][SQL] Implement subexpr elimination for UnsafeProjections This patch adds the building blocks for codegening subexpr elimination and implements it end to end for UnsafeProjection. The building blocks can be used to do the same thing for other operators. It introduces some utilities to compute common sub expressions. Expressions can be added to this data structure. The expr and its children will be recursively matched against existing expressions (ones previously added) and grouped into common groups. This is built using the existing `semanticEquals`. It does not understand things like commutative or associative expressions. This can be done as future work. After building this data structure, the codegen process takes advantage of it by: 1. Generating a helper function in the generated class that computes the common subexpression. This is done for all common subexpressions that have at least two occurrences and the expression tree is sufficiently complex. 2. When generating the apply() function, if the helper function exists, call that instead of regenerating the expression tree. Repeated calls to the helper function shortcircuit the evaluation logic. Author: Nong Li Author: Nong Li This patch had conflicts when merged, resolved by Committer: Michael Armbrust Closes #9480 from nongli/spark-10371. --- .../expressions/EquivalentExpressions.scala | 106 ++++++++++++ .../sql/catalyst/expressions/Expression.scala | 50 +++++- .../sql/catalyst/expressions/Projection.scala | 16 ++ .../expressions/codegen/CodeGenerator.scala | 110 ++++++++++++- .../codegen/GenerateUnsafeProjection.scala | 36 ++++- .../expressions/namedExpressions.scala | 4 + .../SubexpressionEliminationSuite.scala | 153 ++++++++++++++++++ .../scala/org/apache/spark/sql/SQLConf.scala | 8 + .../spark/sql/execution/SparkPlan.scala | 5 + .../spark/sql/execution/basicOperators.scala | 3 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 48 ++++++ 11 files changed, 523 insertions(+), 16 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala new file mode 100644 index 0000000000000..e7380d21f98af --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.collection.mutable + +/** + * This class is used to compute equality of (sub)expression trees. Expressions can be added + * to this class and they subsequently query for expression equality. Expression trees are + * considered equal if for the same input(s), the same result is produced. + */ +class EquivalentExpressions { + /** + * Wrapper around an Expression that provides semantic equality. + */ + case class Expr(e: Expression) { + val hash = e.semanticHash() + override def equals(o: Any): Boolean = o match { + case other: Expr => e.semanticEquals(other.e) + case _ => false + } + override def hashCode: Int = hash + } + + // For each expression, the set of equivalent expressions. + private val equivalenceMap: mutable.HashMap[Expr, mutable.MutableList[Expression]] = + new mutable.HashMap[Expr, mutable.MutableList[Expression]] + + /** + * Adds each expression to this data structure, grouping them with existing equivalent + * expressions. Non-recursive. + * Returns if there was already a matching expression. + */ + def addExpr(expr: Expression): Boolean = { + if (expr.deterministic) { + val e: Expr = Expr(expr) + val f = equivalenceMap.get(e) + if (f.isDefined) { + f.get.+= (expr) + true + } else { + equivalenceMap.put(e, mutable.MutableList(expr)) + false + } + } else { + false + } + } + + /** + * Adds the expression to this datastructure recursively. Stops if a matching expression + * is found. That is, if `expr` has already been added, its children are not added. + * If ignoreLeaf is true, leaf nodes are ignored. + */ + def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = { + val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf + if (!skip && root.deterministic && !addExpr(root)) { + root.children.foreach(addExprTree(_, ignoreLeaf)) + } + } + + /** + * Returns all fo the expression trees that are equivalent to `e`. Returns + * an empty collection if there are none. + */ + def getEquivalentExprs(e: Expression): Seq[Expression] = { + equivalenceMap.get(Expr(e)).getOrElse(mutable.MutableList()) + } + + /** + * Returns all the equivalent sets of expressions. + */ + def getAllEquivalentExprs: Seq[Seq[Expression]] = { + equivalenceMap.values.map(_.toSeq).toSeq + } + + /** + * Returns the state of the datastructure as a string. If all is false, skips sets of equivalent + * expressions with cardinality 1. + */ + def debugString(all: Boolean = false): String = { + val sb: mutable.StringBuilder = new StringBuilder() + sb.append("Equivalent expressions:\n") + equivalenceMap.foreach { case (k, v) => { + if (all || v.length > 1) { + sb.append(" " + v.mkString(", ")).append("\n") + } + }} + sb.toString() + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 96fcc799e537a..7d5741eefcc7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -92,12 +92,24 @@ abstract class Expression extends TreeNode[Expression] { * @return [[GeneratedExpressionCode]] */ def gen(ctx: CodeGenContext): GeneratedExpressionCode = { - val isNull = ctx.freshName("isNull") - val primitive = ctx.freshName("primitive") - val ve = GeneratedExpressionCode("", isNull, primitive) - ve.code = genCode(ctx, ve) - // Add `this` in the comment. - ve.copy(s"/* $this */\n" + ve.code) + val subExprState = ctx.subExprEliminationExprs.get(this) + if (subExprState.isDefined) { + // This expression is repeated meaning the code to evaluated has already been added + // as a function, `subExprState.fnName`. Just call that. + val code = + s""" + |/* $this */ + |${subExprState.get.fnName}(${ctx.INPUT_ROW}); + |""".stripMargin.trim + GeneratedExpressionCode(code, subExprState.get.code.isNull, subExprState.get.code.value) + } else { + val isNull = ctx.freshName("isNull") + val primitive = ctx.freshName("primitive") + val ve = GeneratedExpressionCode("", isNull, primitive) + ve.code = genCode(ctx, ve) + // Add `this` in the comment. + ve.copy(s"/* $this */\n" + ve.code.trim) + } } /** @@ -145,11 +157,37 @@ abstract class Expression extends TreeNode[Expression] { case (i1, i2) => i1 == i2 } } + // Non-determinstic expressions cannot be equal + if (!deterministic || !other.deterministic) return false val elements1 = this.productIterator.toSeq val elements2 = other.asInstanceOf[Product].productIterator.toSeq checkSemantic(elements1, elements2) } + /** + * Returns the hash for this expression. Expressions that compute the same result, even if + * they differ cosmetically should return the same hash. + */ + def semanticHash() : Int = { + def computeHash(e: Seq[Any]): Int = { + // See http://stackoverflow.com/questions/113511/hash-code-implementation + var hash: Int = 17 + e.foreach(i => { + val h: Int = i match { + case (e: Expression) => e.semanticHash() + case (Some(e: Expression)) => e.semanticHash() + case (t: Traversable[_]) => computeHash(t.toSeq) + case null => 0 + case (o) => o.hashCode() + } + hash = hash * 37 + h + }) + hash + } + + computeHash(this.productIterator.toSeq) + } + /** * Checks the input data types, returns `TypeCheckResult.success` if it's valid, * or returns a `TypeCheckResult` with an error message if invalid. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 79dabe8e925ad..9f0b7821ae74a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -144,6 +144,22 @@ object UnsafeProjection { def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { create(exprs.map(BindReferences.bindReference(_, inputSchema))) } + + /** + * Same as other create()'s but allowing enabling/disabling subexpression elimination. + * TODO: refactor the plumbing and clean this up. + */ + def create( + exprs: Seq[Expression], + inputSchema: Seq[Attribute], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + val e = exprs.map(BindReferences.bindReference(_, inputSchema)) + .map(_ transform { + case CreateStruct(children) => CreateStructUnsafe(children) + case CreateNamedStruct(children) => CreateNamedStructUnsafe(children) + }) + GenerateUnsafeProjection.generate(e, subexpressionEliminationEnabled) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f0f7a6cf0cc4d..60a3d6018496c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -92,6 +92,33 @@ class CodeGenContext { addedFunctions += ((funcName, funcCode)) } + /** + * Holds expressions that are equivalent. Used to perform subexpression elimination + * during codegen. + * + * For expressions that appear more than once, generate additional code to prevent + * recomputing the value. + * + * For example, consider two exprsesion generated from this SQL statement: + * SELECT (col1 + col2), (col1 + col2) / col3. + * + * equivalentExpressions will match the tree containing `col1 + col2` and it will only + * be evaluated once. + */ + val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions + + // State used for subexpression elimination. + case class SubExprEliminationState( + val isLoaded: String, code: GeneratedExpressionCode, val fnName: String) + + // Foreach expression that is participating in subexpression elimination, the state to use. + val subExprEliminationExprs: mutable.HashMap[Expression, SubExprEliminationState] = + mutable.HashMap[Expression, SubExprEliminationState]() + + // The collection of isLoaded variables that need to be reset on each row. + val subExprIsLoadedVariables: mutable.ArrayBuffer[String] = + mutable.ArrayBuffer.empty[String] + final val JAVA_BOOLEAN = "boolean" final val JAVA_BYTE = "byte" final val JAVA_SHORT = "short" @@ -317,6 +344,87 @@ class CodeGenContext { functions.map(name => s"$name($row);").mkString("\n") } } + + /** + * Checks and sets up the state and codegen for subexpression elimination. This finds the + * common subexpresses, generates the functions that evaluate those expressions and populates + * the mapping of common subexpressions to the generated functions. + */ + private def subexpressionElimination(expressions: Seq[Expression]) = { + // Add each expression tree and compute the common subexpressions. + expressions.foreach(equivalentExpressions.addExprTree(_)) + + // Get all the exprs that appear at least twice and set up the state for subexpression + // elimination. + val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) + commonExprs.foreach(e => { + val expr = e.head + val isLoaded = freshName("isLoaded") + val isNull = freshName("isNull") + val primitive = freshName("primitive") + val fnName = freshName("evalExpr") + + // Generate the code for this expression tree and wrap it in a function. + val code = expr.gen(this) + val fn = + s""" + |private void $fnName(InternalRow ${INPUT_ROW}) { + | if (!$isLoaded) { + | ${code.code.trim} + | $isLoaded = true; + | $isNull = ${code.isNull}; + | $primitive = ${code.value}; + | } + |} + """.stripMargin + code.code = fn + code.isNull = isNull + code.value = primitive + + addNewFunction(fnName, fn) + + // Add a state and a mapping of the common subexpressions that are associate with this + // state. Adding this expression to subExprEliminationExprMap means it will call `fn` + // when it is code generated. This decision should be a cost based one. + // + // The cost of doing subexpression elimination is: + // 1. Extra function call, although this is probably *good* as the JIT can decide to + // inline or not. + // 2. Extra branch to check isLoaded. This branch is likely to be predicted correctly + // very often. The reason it is not loaded is because of a prior branch. + // 3. Extra store into isLoaded. + // The benefit doing subexpression elimination is: + // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 + // above. + // 2. Less code. + // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with + // at least two nodes) as the cost of doing it is expected to be low. + + // Maintain the loaded value and isNull as member variables. This is necessary if the codegen + // function is split across multiple functions. + // TODO: maintaining this as a local variable probably allows the compiler to do better + // optimizations. + addMutableState("boolean", isLoaded, s"$isLoaded = false;") + addMutableState("boolean", isNull, s"$isNull = false;") + addMutableState(javaType(expr.dataType), primitive, + s"$primitive = ${defaultValue(expr.dataType)};") + subExprIsLoadedVariables += isLoaded + + val state = SubExprEliminationState(isLoaded, code, fnName) + e.foreach(subExprEliminationExprs.put(_, state)) + }) + } + + /** + * Generates code for expressions. If doSubexpressionElimination is true, subexpression + * elimination will be performed. Subexpression elimination assumes that the code will for each + * expression will be combined in the `expressions` order. + */ + def generateExpressions(expressions: Seq[Expression], + doSubexpressionElimination: Boolean = false): Seq[GeneratedExpressionCode] = { + if (doSubexpressionElimination) subexpressionElimination(expressions) + expressions.map(e => e.gen(this)) + } } /** @@ -349,7 +457,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } protected def declareAddedFunctions(ctx: CodeGenContext): String = { - ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") + ctx.addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n").trim } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 2136f82ba4752..9ef226141421b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -139,9 +139,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" ${input.code} if (${input.isNull}) { - $setNull + ${setNull.trim} } else { - $writeField + ${writeField.trim} } """ } @@ -149,7 +149,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s""" $rowWriter.initialize($bufferHolder, ${inputs.length}); ${ctx.splitExpressions(row, writeFields)} - """ + """.trim } // TODO: if the nullability of array element is correct, we can use it to save null check. @@ -275,8 +275,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ } - def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = { - val exprEvals = expressions.map(e => e.gen(ctx)) + def createCode( + ctx: CodeGenContext, + expressions: Seq[Expression], + useSubexprElimination: Boolean = false): GeneratedExpressionCode = { + val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) val exprTypes = expressions.map(_.dataType) val result = ctx.freshName("result") @@ -285,10 +288,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") + // Reset the isLoaded flag for each row. + val subexprReset = ctx.subExprIsLoadedVariables.map { v => s"${v} = false;" }.mkString("\n") + val code = s""" $bufferHolder.reset(); + $subexprReset ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} + $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize()); """ GeneratedExpressionCode(code, "false", result) @@ -300,10 +308,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) + def generate( + expressions: Seq[Expression], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + create(canonicalize(expressions), subexpressionEliminationEnabled) + } + protected def create(expressions: Seq[Expression]): UnsafeProjection = { - val ctx = newCodeGenContext() + create(expressions, false) + } - val eval = createCode(ctx, expressions) + private def create( + expressions: Seq[Expression], + subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + val ctx = newCodeGenContext() + val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) val code = s""" public Object generate($exprType[] exprs) { @@ -315,6 +334,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private $exprType[] expressions; ${declareMutableStates(ctx)} + ${declareAddedFunctions(ctx)} public SpecificUnsafeProjection($exprType[] expressions) { @@ -328,7 +348,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - ${eval.code} + ${eval.code.trim} return ${eval.value}; } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 9ab5c299d0f55..f80bcfcb0b0bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -203,6 +203,10 @@ case class AttributeReference( case _ => false } + override def semanticHash(): Int = { + this.exprId.hashCode() + } + override def hashCode: Int = { // See http://stackoverflow.com/questions/113511/hash-code-implementation var h = 17 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala new file mode 100644 index 0000000000000..9de066e99d637 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types.IntegerType + +class SubexpressionEliminationSuite extends SparkFunSuite { + test("Semantic equals and hash") { + val id = ExprId(1) + val a: AttributeReference = AttributeReference("name", IntegerType)() + val b1 = a.withName("name2").withExprId(id) + val b2 = a.withExprId(id) + + assert(b1 != b2) + assert(a != b1) + assert(b1.semanticEquals(b2)) + assert(!b1.semanticEquals(a)) + assert(a.hashCode != b1.hashCode) + assert(b1.hashCode == b2.hashCode) + assert(b1.semanticHash() == b2.semanticHash()) + } + + test("Expression Equivalence - basic") { + val equivalence = new EquivalentExpressions + assert(equivalence.getAllEquivalentExprs.isEmpty) + + val oneA = Literal(1) + val oneB = Literal(1) + val twoA = Literal(2) + var twoB = Literal(2) + + assert(equivalence.getEquivalentExprs(oneA).isEmpty) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + + // Add oneA and test if it is returned. Since it is a group of one, it does not. + assert(!equivalence.addExpr(oneA)) + assert(equivalence.getEquivalentExprs(oneA).size == 1) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + assert(equivalence.addExpr((oneA))) + assert(equivalence.getEquivalentExprs(oneA).size == 2) + + // Add B and make sure they can see each other. + assert(equivalence.addExpr(oneB)) + // Use exists and reference equality because of how equals is defined. + assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneB)) + assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneA)) + assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA)) + assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB)) + assert(equivalence.getEquivalentExprs(twoA).isEmpty) + assert(equivalence.getAllEquivalentExprs.size == 1) + assert(equivalence.getAllEquivalentExprs.head.size == 3) + assert(equivalence.getAllEquivalentExprs.head.contains(oneA)) + assert(equivalence.getAllEquivalentExprs.head.contains(oneB)) + + val add1 = Add(oneA, oneB) + val add2 = Add(oneA, oneB) + + equivalence.addExpr(add1) + equivalence.addExpr(add2) + + assert(equivalence.getAllEquivalentExprs.size == 2) + assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1)) + assert(equivalence.getEquivalentExprs(add2).size == 2) + assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2)) + } + + test("Expression Equivalence - Trees") { + val one = Literal(1) + val two = Literal(2) + + val add = Add(one, two) + val abs = Abs(add) + val add2 = Add(add, add) + + var equivalence = new EquivalentExpressions + equivalence.addExprTree(add, true) + equivalence.addExprTree(abs, true) + equivalence.addExprTree(add2, true) + + // Should only have one equivalence for `one + two` + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 1) + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).head.size == 4) + + // Set up the expressions + // one * two, + // (one * two) * (one * two) + // sqrt( (one * two) * (one * two) ) + // (one * two) + sqrt( (one * two) * (one * two) ) + equivalence = new EquivalentExpressions + val mul = Multiply(one, two) + val mul2 = Multiply(mul, mul) + val sqrt = Sqrt(mul2) + val sum = Add(mul2, sqrt) + equivalence.addExprTree(mul, true) + equivalence.addExprTree(mul2, true) + equivalence.addExprTree(sqrt, true) + equivalence.addExprTree(sum, true) + + // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 3) + assert(equivalence.getEquivalentExprs(mul).size == 3) + assert(equivalence.getEquivalentExprs(mul2).size == 3) + assert(equivalence.getEquivalentExprs(sqrt).size == 2) + assert(equivalence.getEquivalentExprs(sum).size == 1) + + // Some expressions inspired by TPCH-Q1 + // sum(l_quantity) as sum_qty, + // sum(l_extendedprice) as sum_base_price, + // sum(l_extendedprice * (1 - l_discount)) as sum_disc_price, + // sum(l_extendedprice * (1 - l_discount) * (1 + l_tax)) as sum_charge, + // avg(l_extendedprice) as avg_price, + // avg(l_discount) as avg_disc + equivalence = new EquivalentExpressions + val quantity = Literal(1) + val price = Literal(1.1) + val discount = Literal(.24) + val tax = Literal(0.1) + equivalence.addExprTree(quantity, false) + equivalence.addExprTree(price, false) + equivalence.addExprTree(Multiply(price, Subtract(Literal(1), discount)), false) + equivalence.addExprTree( + Multiply( + Multiply(price, Subtract(Literal(1), discount)), + Add(Literal(1), tax)), false) + equivalence.addExprTree(price, false) + equivalence.addExprTree(discount, false) + // quantity, price, discount and (price * (1 - discount)) + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 4) + } + + test("Expression equivalence - non deterministic") { + val sum = Add(Rand(0), Rand(0)) + val equivalence = new EquivalentExpressions + equivalence.addExpr(sum) + equivalence.addExpr(sum) + assert(equivalence.getAllEquivalentExprs.isEmpty) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index b7314189b5403..89e196c066007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -268,6 +268,11 @@ private[spark] object SQLConf { doc = "When true, use the new optimized Tungsten physical execution backend.", isPublic = false) + val SUBEXPRESSION_ELIMINATION_ENABLED = booleanConf("spark.sql.subexpressionElimination.enabled", + defaultValue = Some(true), // use CODEGEN_ENABLED as default + doc = "When true, common subexpressions will be eliminated.", + isPublic = false) + val DIALECT = stringConf( "spark.sql.dialect", defaultValue = Some("sql"), @@ -541,6 +546,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def unsafeEnabled: Boolean = getConf(UNSAFE_ENABLED, getConf(TUNGSTEN_ENABLED)) + private[spark] def subexpressionEliminationEnabled: Boolean = + getConf(SUBEXPRESSION_ELIMINATION_ENABLED, codegenEnabled) + private[spark] def autoBroadcastJoinThreshold: Int = getConf(AUTO_BROADCASTJOIN_THRESHOLD) private[spark] def defaultSizeInBytes: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 8bb293ae87e64..8650ac500b652 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -66,6 +66,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } else { false } + val subexpressionEliminationEnabled: Boolean = if (sqlContext != null) { + sqlContext.conf.subexpressionEliminationEnabled + } else { + false + } /** * Whether the "prepare" method is called. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 145de0db9edaa..303d636164adb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -70,7 +70,8 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan) protected override def doExecute(): RDD[InternalRow] = { val numRows = longMetric("numRows") child.execute().mapPartitions { iter => - val project = UnsafeProjection.create(projectList, child.output) + val project = UnsafeProjection.create(projectList, child.output, + subexpressionEliminationEnabled) iter.map { row => numRows += 1 project(row) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 441a0c6d0e36e..19e850a46fdfc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1970,4 +1970,52 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(1, 1) :: Row(1, 2) :: Row(2, 1) :: Row(2, 2) :: Row(3, 1) :: Row(3, 2) :: Nil) } } + + test("Common subexpression elimination") { + // select from a table to prevent constant folding. + val df = sql("SELECT a, b from testData2 limit 1") + checkAnswer(df, Row(1, 1)) + + checkAnswer(df.selectExpr("a + 1", "a + 1"), Row(2, 2)) + checkAnswer(df.selectExpr("a + 1", "a + 1 + 1"), Row(2, 3)) + + // This does not work because the expressions get grouped like (a + a) + 1 + checkAnswer(df.selectExpr("a + 1", "a + a + 1"), Row(2, 3)) + checkAnswer(df.selectExpr("a + 1", "a + (a + 1)"), Row(2, 3)) + + // Identity udf that tracks the number of times it is called. + val countAcc = sparkContext.accumulator(0, "CallCount") + sqlContext.udf.register("testUdf", (x: Int) => { + countAcc.++=(1) + x + }) + + // Evaluates df, verifying it is equal to the expectedResult and the accumulator's value + // is correct. + def verifyCallCount(df: DataFrame, expectedResult: Row, expectedCount: Int): Unit = { + countAcc.setValue(0) + checkAnswer(df, expectedResult) + assert(countAcc.value == expectedCount) + } + + verifyCallCount(df.selectExpr("testUdf(a)"), Row(1), 1) + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a + 1)"), Row(2, 2), 1) + verifyCallCount(df.selectExpr("testUdf(a + 1)", "testUdf(a)"), Row(2, 1), 2) + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(a + 1)", "testUdf(a + 1)"), Row(4, 2), 1) + + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Would be nice if semantic equals for `+` understood commutative + verifyCallCount( + df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) + + // Try disabling it via configuration. + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "false") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 2) + sqlContext.setConf("spark.sql.subexpressionElimination.enabled", "true") + verifyCallCount(df.selectExpr("testUdf(a)", "testUdf(a)"), Row(1, 1), 1) + } } From f14e95115c0939a77ebcb00209696a87fd651ff9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 10 Nov 2015 11:34:36 -0800 Subject: [PATCH 82/88] [ML][R] SparkR::glm summary result to compare with native R Follow up #9561. Due to [SPARK-11587](https://issues.apache.org/jira/browse/SPARK-11587) has been fixed, we should compare SparkR::glm summary result with native R output rather than hard-code one. mengxr Author: Yanbo Liang Closes #9590 from yanboliang/glm-r-test. --- R/pkg/R/mllib.R | 2 +- R/pkg/inst/tests/test_mllib.R | 31 ++++++++++--------------------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 7126b7cde4bd7..f23e1c7f1fce4 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -106,7 +106,7 @@ setMethod("summary", signature(object = "PipelineModel"), coefficients <- matrix(coefficients, ncol = 4) colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") rownames(coefficients) <- unlist(features) - return(list(DevianceResiduals = devianceResiduals, Coefficients = coefficients)) + return(list(devianceResiduals = devianceResiduals, coefficients = coefficients)) } else { coefficients <- as.matrix(unlist(coefficients)) colnames(coefficients) <- c("Estimate") diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 42287ea19adc5..d497ad8c9daa3 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -72,22 +72,17 @@ test_that("feature interaction vs native glm", { test_that("summary coefficients match with native glm", { training <- createDataFrame(sqlContext, iris) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal")) - coefs <- unlist(stats$Coefficients) - devianceResiduals <- unlist(stats$DevianceResiduals) + coefs <- unlist(stats$coefficients) + devianceResiduals <- unlist(stats$devianceResiduals) - rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) - rStdError <- c(0.23536, 0.04630, 0.07207, 0.09331) - rTValue <- c(7.123, 7.557, -13.644, -10.798) - rPValue <- c(0.0, 0.0, 0.0, 0.0) + rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) + rCoefs <- unlist(rStats$coefficients) rDevianceResiduals <- c(-0.95096, 0.72918) - expect_true(all(abs(rCoefs - coefs[1:4]) < 1e-6)) - expect_true(all(abs(rStdError - coefs[5:8]) < 1e-5)) - expect_true(all(abs(rTValue - coefs[9:12]) < 1e-3)) - expect_true(all(abs(rPValue - coefs[13:16]) < 1e-6)) + expect_true(all(abs(rCoefs - coefs) < 1e-5)) expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5)) expect_true(all( - rownames(stats$Coefficients) == + rownames(stats$coefficients) == c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) }) @@ -96,21 +91,15 @@ test_that("summary coefficients match with native glm of family 'binomial'", { training <- filter(df, df$Species != "setosa") stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, family = "binomial")) - coefs <- as.vector(stats$Coefficients) + coefs <- as.vector(stats$coefficients[,1]) rTraining <- iris[iris$Species %in% c("versicolor","virginica"),] rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, family = binomial(link = "logit")))) - rStdError <- c(3.0974, 0.5169, 0.8628) - rTValue <- c(-4.212, 3.680, 0.469) - rPValue <- c(0.000, 0.000, 0.639) - - expect_true(all(abs(rCoefs - coefs[1:3]) < 1e-4)) - expect_true(all(abs(rStdError - coefs[4:6]) < 1e-4)) - expect_true(all(abs(rTValue - coefs[7:9]) < 1e-3)) - expect_true(all(abs(rPValue - coefs[10:12]) < 1e-3)) + + expect_true(all(abs(rCoefs - coefs) < 1e-4)) expect_true(all( - rownames(stats$Coefficients) == + rownames(stats$coefficients) == c("(Intercept)", "Sepal_Length", "Sepal_Width"))) }) From 18350a57004eb87cafa9504ff73affab4b818e06 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 10 Nov 2015 11:36:43 -0800 Subject: [PATCH 83/88] [SPARK-11618][ML] Minor refactoring of basic ML import/export Refactoring * separated overwrite and param save logic in DefaultParamsWriter * added sparkVersion to DefaultParamsWriter CC: mengxr Author: Joseph K. Bradley Closes #9587 from jkbradley/logreg-io. --- .../org/apache/spark/ml/util/ReadWrite.scala | 57 ++++++++++--------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index ea790e0dddc7f..cbdf913ba8dfa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -51,6 +51,9 @@ private[util] sealed trait BaseReadWrite { protected final def sqlContext: SQLContext = optionSQLContext.getOrElse { SQLContext.getOrCreate(SparkContext.getOrCreate()) } + + /** Returns the [[SparkContext]] underlying [[sqlContext]] */ + protected final def sc: SparkContext = sqlContext.sparkContext } /** @@ -58,7 +61,7 @@ private[util] sealed trait BaseReadWrite { */ @Experimental @Since("1.6.0") -abstract class Writer extends BaseReadWrite { +abstract class Writer extends BaseReadWrite with Logging { protected var shouldOverwrite: Boolean = false @@ -67,7 +70,29 @@ abstract class Writer extends BaseReadWrite { */ @Since("1.6.0") @throws[IOException]("If the input path already exists but overwrite is not enabled.") - def save(path: String): Unit + def save(path: String): Unit = { + val hadoopConf = sc.hadoopConfiguration + val fs = FileSystem.get(hadoopConf) + val p = new Path(path) + if (fs.exists(p)) { + if (shouldOverwrite) { + logInfo(s"Path $path already exists. It will be overwritten.") + // TODO: Revert back to the original content if save is not successful. + fs.delete(p, true) + } else { + throw new IOException( + s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") + } + } + saveImpl(path) + } + + /** + * [[save()]] handles overwriting and then calls this method. Subclasses should override this + * method to implement the actual saving of the instance. + */ + @Since("1.6.0") + protected def saveImpl(path: String): Unit /** * Overwrites if the output path already exists. @@ -147,28 +172,9 @@ trait Readable[T] { * data (e.g., models with coefficients). * @param instance object to save */ -private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging { - - /** - * Saves the ML component to the input path. - */ - override def save(path: String): Unit = { - val sc = sqlContext.sparkContext - - val hadoopConf = sc.hadoopConfiguration - val fs = FileSystem.get(hadoopConf) - val p = new Path(path) - if (fs.exists(p)) { - if (shouldOverwrite) { - logInfo(s"Path $path already exists. It will be overwritten.") - // TODO: Revert back to the original content if save is not successful. - fs.delete(p, true) - } else { - throw new IOException( - s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.") - } - } +private[ml] class DefaultParamsWriter(instance: Params) extends Writer { + override protected def saveImpl(path: String): Unit = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] @@ -177,6 +183,7 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg }.toList val metadata = ("class" -> cls) ~ ("timestamp" -> System.currentTimeMillis()) ~ + ("sparkVersion" -> sc.version) ~ ("uid" -> uid) ~ ("paramMap" -> jsonParams) val metadataPath = new Path(path, "metadata").toString @@ -193,12 +200,8 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg */ private[ml] class DefaultParamsReader[T] extends Reader[T] { - /** - * Loads the ML component from the input path. - */ override def load(path: String): T = { implicit val format = DefaultFormats - val sc = sqlContext.sparkContext val metadataPath = new Path(path, "metadata").toString val metadataStr = sc.textFile(metadataPath, 1).first() val metadata = parse(metadataStr) From dba1a62cf1baa9ae1ee665d592e01dfad78331a2 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 10 Nov 2015 14:25:06 -0800 Subject: [PATCH 84/88] [SPARK-7316][MLLIB] RDD sliding window with step Implementation of step capability for sliding window function in MLlib's RDD. Though one can use current sliding window with step 1 and then filter every Nth window, it will take more time and space (N*data.count times more than needed). For example, below are the results for various windows and steps on 10M data points: Window | Step | Time | Windows produced ------------ | ------------- | ---------- | ---------- 128 | 1 | 6.38 | 9999873 128 | 10 | 0.9 | 999988 128 | 100 | 0.41 | 99999 1024 | 1 | 44.67 | 9998977 1024 | 10 | 4.74 | 999898 1024 | 100 | 0.78 | 99990 ``` import org.apache.spark.mllib.rdd.RDDFunctions._ val rdd = sc.parallelize(1 to 10000000, 10) rdd.count val window = 1024 val step = 1 val t = System.nanoTime(); val windows = rdd.sliding(window, step); println(windows.count); println((System.nanoTime() - t) / 1e9) ``` Author: unknown Author: Alexander Ulanov Author: Xiangrui Meng Closes #5855 from avulanov/SPARK-7316-sliding. --- .../apache/spark/mllib/rdd/RDDFunctions.scala | 11 ++- .../apache/spark/mllib/rdd/SlidingRDD.scala | 71 ++++++++++--------- .../spark/mllib/rdd/RDDFunctionsSuite.scala | 11 +-- 3 files changed, 54 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 78172843be56e..19a047ded257c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -37,15 +37,20 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { * trigger a Spark job if the parent RDD has more than one partitions and the window size is * greater than 1. */ - def sliding(windowSize: Int): RDD[Array[T]] = { + def sliding(windowSize: Int, step: Int): RDD[Array[T]] = { require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") - if (windowSize == 1) { + if (windowSize == 1 && step == 1) { self.map(Array(_)) } else { - new SlidingRDD[T](self, windowSize) + new SlidingRDD[T](self, windowSize, step) } } + /** + * [[sliding(Int, Int)*]] with step = 1. + */ + def sliding(windowSize: Int): RDD[Array[T]] = sliding(windowSize, 1) + /** * Reduces the elements of this RDD in a multi-level tree pattern. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index 1facf83d806d0..ead8db6344998 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -24,13 +24,13 @@ import org.apache.spark.{TaskContext, Partition} import org.apache.spark.rdd.RDD private[mllib] -class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T]) +class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T], val offset: Int) extends Partition with Serializable { override val index: Int = idx } /** - * Represents a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * Represents an RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding * window over them. The ordering is first based on the partition index and then the ordering of * items within each partition. This is similar to sliding in Scala collections, except that it * becomes an empty RDD if the window size is greater than the total number of items. It needs to @@ -40,19 +40,24 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T] * * @param parent the parent RDD * @param windowSize the window size, must be greater than 1 + * @param step step size for windows * - * @see [[org.apache.spark.mllib.rdd.RDDFunctions#sliding]] + * @see [[org.apache.spark.mllib.rdd.RDDFunctions.sliding(Int, Int)*]] + * @see [[scala.collection.IterableLike.sliding(Int, Int)*]] */ private[mllib] -class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) +class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int, val step: Int) extends RDD[Array[T]](parent) { - require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") + require(windowSize > 0 && step > 0 && !(windowSize == 1 && step == 1), + "Window size and step must be greater than 0, " + + s"and they cannot be both 1, but got windowSize = $windowSize and step = $step.") override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = { val part = split.asInstanceOf[SlidingRDDPartition[T]] (firstParent[T].iterator(part.prev, context) ++ part.tail) - .sliding(windowSize) + .drop(part.offset) + .sliding(windowSize, step) .withPartial(false) .map(_.toArray) } @@ -62,40 +67,42 @@ class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int override def getPartitions: Array[Partition] = { val parentPartitions = parent.partitions - val n = parentPartitions.size + val n = parentPartitions.length if (n == 0) { Array.empty } else if (n == 1) { - Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty)) + Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty, 0)) } else { - val n1 = n - 1 val w1 = windowSize - 1 - // Get the first w1 items of each partition, starting from the second partition. - val nextHeads = - parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n) - val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]() + // Get partition sizes and first w1 elements. + val (sizes, heads) = parent.mapPartitions { iter => + val w1Array = iter.take(w1).toArray + Iterator.single((w1Array.length + iter.length, w1Array)) + }.collect().unzip + val partitions = mutable.ArrayBuffer.empty[SlidingRDDPartition[T]] var i = 0 + var cumSize = 0 var partitionIndex = 0 - while (i < n1) { - var j = i - val tail = mutable.ListBuffer[T]() - // Keep appending to the current tail until appended a head of size w1. - while (j < n1 && nextHeads(j).size < w1) { - tail ++= nextHeads(j) - j += 1 + while (i < n) { + val mod = cumSize % step + val offset = if (mod == 0) 0 else step - mod + val size = sizes(i) + if (offset < size) { + val tail = mutable.ListBuffer.empty[T] + // Keep appending to the current tail until it has w1 elements. + var j = i + 1 + while (j < n && tail.length < w1) { + tail ++= heads(j).take(w1 - tail.length) + j += 1 + } + if (sizes(i) + tail.length >= offset + windowSize) { + partitions += + new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail, offset) + partitionIndex += 1 + } } - if (j < n1) { - tail ++= nextHeads(j) - j += 1 - } - partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail) - partitionIndex += 1 - // Skip appended heads. - i = j - } - // If the head of last partition has size w1, we also need to add this partition. - if (nextHeads.last.size == w1) { - partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(n1), Seq.empty) + cumSize += size + i += 1 } partitions.toArray } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index bc64172614830..ac93733bab5f5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -28,9 +28,12 @@ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { for (numPartitions <- 1 to 8) { val rdd = sc.parallelize(data, numPartitions) for (windowSize <- 1 to 6) { - val sliding = rdd.sliding(windowSize).collect().map(_.toList).toList - val expected = data.sliding(windowSize).map(_.toList).toList - assert(sliding === expected) + for (step <- 1 to 3) { + val sliding = rdd.sliding(windowSize, step).collect().map(_.toList).toList + val expected = data.sliding(windowSize, step) + .map(_.toList).toList.filter(l => l.size == windowSize) + assert(sliding === expected) + } } assert(rdd.sliding(7).collect().isEmpty, "Should return an empty RDD if the window size is greater than the number of items.") @@ -40,7 +43,7 @@ class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("sliding with empty partitions") { val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) val rdd = sc.parallelize(data, data.length).flatMap(s => s) - assert(rdd.partitions.size === data.length) + assert(rdd.partitions.length === data.length) val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq) val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) assert(sliding === expected) From 724cf7a38c551bf2a79b87a8158bbe1725f9f888 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 10 Nov 2015 14:30:19 -0800 Subject: [PATCH 85/88] [SPARK-11616][SQL] Improve toString for Dataset Author: Michael Armbrust Closes #9586 from marmbrus/dataset-toString. --- .../org/apache/spark/sql/DataFrame.scala | 14 ++----- .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../spark/sql/execution/Queryable.scala | 37 +++++++++++++++++++ .../org/apache/spark/sql/DatasetSuite.scala | 5 +++ 4 files changed, 47 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 9368435a63c35..691b476fff8d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, SQLExecution} +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -116,7 +116,8 @@ private[sql] object DataFrame { @Experimental class DataFrame private[sql]( @transient val sqlContext: SQLContext, - @DeveloperApi @transient val queryExecution: QueryExecution) extends Serializable { + @DeveloperApi @transient val queryExecution: QueryExecution) + extends Queryable with Serializable { // Note for Spark contributors: if adding or updating any action in `DataFrame`, please make sure // you wrap it with `withNewExecutionId` if this actions doesn't call other action. @@ -234,15 +235,6 @@ class DataFrame private[sql]( sb.toString() } - override def toString: String = { - try { - schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") - } catch { - case NonFatal(e) => - s"Invalid tree; ${e.getMessage}:\n$queryExecution" - } - } - /** * Returns the object itself. * @group basic diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 6d2968e2881f8..a7e5ab19bf846 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.{Queryable, QueryExecution} import org.apache.spark.sql.types.StructType /** @@ -62,7 +62,7 @@ import org.apache.spark.sql.types.StructType class Dataset[T] private[sql]( @transient val sqlContext: SQLContext, @transient val queryExecution: QueryExecution, - unresolvedEncoder: Encoder[T]) extends Serializable { + unresolvedEncoder: Encoder[T]) extends Queryable with Serializable { /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala new file mode 100644 index 0000000000000..9ca383896a09b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.types.StructType + +import scala.util.control.NonFatal + +/** A trait that holds shared code between DataFrames and Datasets. */ +private[sql] trait Queryable { + def schema: StructType + def queryExecution: QueryExecution + + override def toString: String = { + try { + schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") + } catch { + case NonFatal(e) => + s"Invalid tree; ${e.getMessage}:\n$queryExecution" + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index aea5a700d0204..621148528714f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -313,4 +313,9 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val joined = ds1.joinWith(ds2, $"a.value" === $"b.value") checkAnswer(joined, ("2", 2)) } + + test("toString") { + val ds = Seq((1, 2)).toDS() + assert(ds.toString == "[_1: int, _2: int]") + } } From 638c51d9380081b3b8182be2c2460bd53b8b0a4f Mon Sep 17 00:00:00 2001 From: Pravin Gadakh Date: Tue, 10 Nov 2015 14:47:04 -0800 Subject: [PATCH 86/88] [SPARK-11550][DOCS] Replace example code in mllib-optimization.md using include_example Author: Pravin Gadakh Closes #9516 from pravingadakh/SPARK-11550. --- docs/mllib-optimization.md | 145 +----------------- .../examples/mllib/JavaLBFGSExample.java | 108 +++++++++++++ .../spark/examples/mllib/LBFGSExample.scala | 90 +++++++++++ 3 files changed, 200 insertions(+), 143 deletions(-) create mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java create mode 100644 examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index a3bd130ba077c..ad7bcd9bfd407 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -220,154 +220,13 @@ L-BFGS optimizer.
Refer to the [`LBFGS` Scala docs](api/scala/index.html#org.apache.spark.mllib.optimization.LBFGS) and [`SquaredL2Updater` Scala docs](api/scala/index.html#org.apache.spark.mllib.optimization.SquaredL2Updater) for details on the API. -{% highlight scala %} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.classification.LogisticRegressionModel -import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater} - -val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -val numFeatures = data.take(1)(0).features.size - -// Split data into training (60%) and test (40%). -val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) - -// Append 1 into the training data as intercept. -val training = splits(0).map(x => (x.label, MLUtils.appendBias(x.features))).cache() - -val test = splits(1) - -// Run training algorithm to build the model -val numCorrections = 10 -val convergenceTol = 1e-4 -val maxNumIterations = 20 -val regParam = 0.1 -val initialWeightsWithIntercept = Vectors.dense(new Array[Double](numFeatures + 1)) - -val (weightsWithIntercept, loss) = LBFGS.runLBFGS( - training, - new LogisticGradient(), - new SquaredL2Updater(), - numCorrections, - convergenceTol, - maxNumIterations, - regParam, - initialWeightsWithIntercept) - -val model = new LogisticRegressionModel( - Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)), - weightsWithIntercept(weightsWithIntercept.size - 1)) - -// Clear the default threshold. -model.clearThreshold() - -// Compute raw scores on the test set. -val scoreAndLabels = test.map { point => - val score = model.predict(point.features) - (score, point.label) -} - -// Get evaluation metrics. -val metrics = new BinaryClassificationMetrics(scoreAndLabels) -val auROC = metrics.areaUnderROC() - -println("Loss of each step in training process") -loss.foreach(println) -println("Area under ROC = " + auROC) -{% endhighlight %} +{% include_example scala/org/apache/spark/examples/mllib/LBFGSExample.scala %}
Refer to the [`LBFGS` Java docs](api/java/org/apache/spark/mllib/optimization/LBFGS.html) and [`SquaredL2Updater` Java docs](api/java/org/apache/spark/mllib/optimization/SquaredL2Updater.html) for details on the API. -{% highlight java %} -import java.util.Arrays; -import java.util.Random; - -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.optimization.*; -import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.mllib.util.MLUtils; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkContext; - -public class LBFGSExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("L-BFGS Example"); - SparkContext sc = new SparkContext(conf); - String path = "data/mllib/sample_libsvm_data.txt"; - JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); - int numFeatures = data.take(1).get(0).features().size(); - - // Split initial RDD into two... [60% training data, 40% testing data]. - JavaRDD trainingInit = data.sample(false, 0.6, 11L); - JavaRDD test = data.subtract(trainingInit); - - // Append 1 into the training data as intercept. - JavaRDD> training = data.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - return new Tuple2(p.label(), MLUtils.appendBias(p.features())); - } - }); - training.cache(); - - // Run training algorithm to build the model. - int numCorrections = 10; - double convergenceTol = 1e-4; - int maxNumIterations = 20; - double regParam = 0.1; - Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]); - - Tuple2 result = LBFGS.runLBFGS( - training.rdd(), - new LogisticGradient(), - new SquaredL2Updater(), - numCorrections, - convergenceTol, - maxNumIterations, - regParam, - initialWeightsWithIntercept); - Vector weightsWithIntercept = result._1(); - double[] loss = result._2(); - - final LogisticRegressionModel model = new LogisticRegressionModel( - Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)), - (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]); - - // Clear the default threshold. - model.clearThreshold(); - - // Compute raw scores on the test set. - JavaRDD> scoreAndLabels = test.map( - new Function>() { - public Tuple2 call(LabeledPoint p) { - Double score = model.predict(p.features()); - return new Tuple2(score, p.label()); - } - }); - - // Get evaluation metrics. - BinaryClassificationMetrics metrics = - new BinaryClassificationMetrics(scoreAndLabels.rdd()); - double auROC = metrics.areaUnderROC(); - - System.out.println("Loss of each step in training process"); - for (double l : loss) - System.out.println(l); - System.out.println("Area under ROC = " + auROC); - } -} -{% endhighlight %} +{% include_example java/org/apache/spark/examples/mllib/JavaLBFGSExample.java %}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java new file mode 100644 index 0000000000000..355883f61bd64 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLBFGSExample.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +// $example on$ +import java.util.Arrays; + +import scala.Tuple2; + +import org.apache.spark.api.java.*; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.optimization.*; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; +import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; +// $example off$ + +public class JavaLBFGSExample { + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("L-BFGS Example"); + SparkContext sc = new SparkContext(conf); + + // $example on$ + String path = "data/mllib/sample_libsvm_data.txt"; + JavaRDD data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD(); + int numFeatures = data.take(1).get(0).features().size(); + + // Split initial RDD into two... [60% training data, 40% testing data]. + JavaRDD trainingInit = data.sample(false, 0.6, 11L); + JavaRDD test = data.subtract(trainingInit); + + // Append 1 into the training data as intercept. + JavaRDD> training = data.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + return new Tuple2(p.label(), MLUtils.appendBias(p.features())); + } + }); + training.cache(); + + // Run training algorithm to build the model. + int numCorrections = 10; + double convergenceTol = 1e-4; + int maxNumIterations = 20; + double regParam = 0.1; + Vector initialWeightsWithIntercept = Vectors.dense(new double[numFeatures + 1]); + + Tuple2 result = LBFGS.runLBFGS( + training.rdd(), + new LogisticGradient(), + new SquaredL2Updater(), + numCorrections, + convergenceTol, + maxNumIterations, + regParam, + initialWeightsWithIntercept); + Vector weightsWithIntercept = result._1(); + double[] loss = result._2(); + + final LogisticRegressionModel model = new LogisticRegressionModel( + Vectors.dense(Arrays.copyOf(weightsWithIntercept.toArray(), weightsWithIntercept.size() - 1)), + (weightsWithIntercept.toArray())[weightsWithIntercept.size() - 1]); + + // Clear the default threshold. + model.clearThreshold(); + + // Compute raw scores on the test set. + JavaRDD> scoreAndLabels = test.map( + new Function>() { + public Tuple2 call(LabeledPoint p) { + Double score = model.predict(p.features()); + return new Tuple2(score, p.label()); + } + }); + + // Get evaluation metrics. + BinaryClassificationMetrics metrics = + new BinaryClassificationMetrics(scoreAndLabels.rdd()); + double auROC = metrics.areaUnderROC(); + + System.out.println("Loss of each step in training process"); + for (double l : loss) + System.out.println(l); + System.out.println("Area under ROC = " + auROC); + // $example off$ + } +} + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala new file mode 100644 index 0000000000000..61d2e7715f53d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.mllib + +// $example on$ +import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater} +import org.apache.spark.mllib.util.MLUtils +// $example off$ + +import org.apache.spark.{SparkConf, SparkContext} + +object LBFGSExample { + + def main(args: Array[String]): Unit = { + + val conf = new SparkConf().setAppName("LBFGSExample") + val sc = new SparkContext(conf) + + // $example on$ + val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") + val numFeatures = data.take(1)(0).features.size + + // Split data into training (60%) and test (40%). + val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L) + + // Append 1 into the training data as intercept. + val training = splits(0).map(x => (x.label, MLUtils.appendBias(x.features))).cache() + + val test = splits(1) + + // Run training algorithm to build the model + val numCorrections = 10 + val convergenceTol = 1e-4 + val maxNumIterations = 20 + val regParam = 0.1 + val initialWeightsWithIntercept = Vectors.dense(new Array[Double](numFeatures + 1)) + + val (weightsWithIntercept, loss) = LBFGS.runLBFGS( + training, + new LogisticGradient(), + new SquaredL2Updater(), + numCorrections, + convergenceTol, + maxNumIterations, + regParam, + initialWeightsWithIntercept) + + val model = new LogisticRegressionModel( + Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1)), + weightsWithIntercept(weightsWithIntercept.size - 1)) + + // Clear the default threshold. + model.clearThreshold() + + // Compute raw scores on the test set. + val scoreAndLabels = test.map { point => + val score = model.predict(point.features) + (score, point.label) + } + + // Get evaluation metrics. + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val auROC = metrics.areaUnderROC() + + println("Loss of each step in training process") + loss.foreach(println) + println("Area under ROC = " + auROC) + // $example off$ + } +} +// scalastyle:on println From 32790fe7249b0efe2cbc5c4ee2df0fb687dcd624 Mon Sep 17 00:00:00 2001 From: felixcheung Date: Tue, 10 Nov 2015 15:47:10 -0800 Subject: [PATCH 87/88] [SPARK-11567] [PYTHON] Add Python API for corr Aggregate function like `df.agg(corr("col1", "col2")` davies Author: felixcheung Closes #9536 from felixcheung/pyfunc. --- python/pyspark/sql/functions.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6e1cbde4239f3..c3da513c13897 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -255,6 +255,22 @@ def coalesce(*cols): return Column(jc) +@since(1.6) +def corr(col1, col2): + """Returns a new :class:`Column` for the Pearson Correlation Coefficient for ``col1`` + and ``col2``. + + >>> a = [x * x - 2 * x + 3.5 for x in range(20)] + >>> b = range(20) + >>> corrDf = sqlContext.createDataFrame(zip(a, b)) + >>> corrDf = corrDf.agg(corr(corrDf._1, corrDf._2).alias('c')) + >>> corrDf.selectExpr('abs(c - 0.9572339139475857) < 1e-16 as t').collect() + [Row(t=True)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.corr(_to_java_column(col1), _to_java_column(col2))) + + @since(1.3) def countDistinct(col, *cols): """Returns a new :class:`Column` for distinct count of ``col`` or ``cols``. From 1dde39d796bbf42336051a86bedf871c7fddd513 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 10 Nov 2015 15:58:30 -0800 Subject: [PATCH 88/88] [SPARK-9818] Re-enable Docker tests for JDBC data source This patch re-enables tests for the Docker JDBC data source. These tests were reverted in #4872 due to transitive dependency conflicts introduced by the `docker-client` library. This patch should avoid those problems by using a version of `docker-client` which shades its transitive dependencies and by performing some build-magic to work around problems with that shaded JAR. In addition, I significantly refactored the tests to simplify the setup and teardown code and to fix several Docker networking issues which caused problems when running in `boot2docker`. Closes #8101. Author: Josh Rosen Author: Yijie Shen Closes #9503 from JoshRosen/docker-jdbc-tests. --- docker-integration-tests/pom.xml | 149 ++++++++++++++++ .../sql/jdbc/DockerJDBCIntegrationSuite.scala | 160 ++++++++++++++++++ .../sql/jdbc/MySQLIntegrationSuite.scala | 153 +++++++++++++++++ .../sql/jdbc/PostgresIntegrationSuite.scala | 82 +++++++++ .../org/apache/spark/util/DockerUtils.scala | 68 ++++++++ pom.xml | 14 ++ project/SparkBuild.scala | 14 +- .../org/apache/spark/tags/DockerTest.java | 26 +++ 8 files changed, 664 insertions(+), 2 deletions(-) create mode 100644 docker-integration-tests/pom.xml create mode 100644 docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala create mode 100644 docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala create mode 100644 docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala create mode 100644 docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala create mode 100644 tags/src/main/java/org/apache/spark/tags/DockerTest.java diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml new file mode 100644 index 0000000000000..dee0c4aa37ae8 --- /dev/null +++ b/docker-integration-tests/pom.xml @@ -0,0 +1,149 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.10 + 1.6.0-SNAPSHOT + ../pom.xml + + + spark-docker-integration-tests_2.10 + jar + Spark Project Docker Integration Tests + http://spark.apache.org/ + + docker-integration-tests + + + + + com.spotify + docker-client + shaded + test + + + + com.fasterxml.jackson.jaxrs + jackson-jaxrs-json-provider + + + com.fasterxml.jackson.datatype + jackson-datatype-guava + + + com.fasterxml.jackson.core + jackson-databind + + + org.glassfish.jersey.core + jersey-client + + + org.glassfish.jersey.connectors + jersey-apache-connector + + + org.glassfish.jersey.media + jersey-media-json-jackson + + + + + + com.google.guava + guava + 18.0 + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-test-tags_${scala.binary.version} + ${project.version} + test + + + + com.sun.jersey + jersey-server + 1.19 + test + + + com.sun.jersey + jersey-core + 1.19 + test + + + com.sun.jersey + jersey-servlet + 1.19 + test + + + com.sun.jersey + jersey-json + 1.19 + test + + + stax + stax-api + + + + + + diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala new file mode 100644 index 0000000000000..c503c4a13b482 --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.net.ServerSocket +import java.sql.Connection + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import com.spotify.docker.client._ +import com.spotify.docker.client.messages.{ContainerConfig, HostConfig, PortBinding} +import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.DockerUtils +import org.apache.spark.sql.test.SharedSQLContext + +abstract class DatabaseOnDocker { + /** + * The docker image to be pulled. + */ + val imageName: String + + /** + * Environment variables to set inside of the Docker container while launching it. + */ + val env: Map[String, String] + + /** + * The container-internal JDBC port that the database listens on. + */ + val jdbcPort: Int + + /** + * Return a JDBC URL that connects to the database running at the given IP address and port. + */ + def getJdbcUrl(ip: String, port: Int): String +} + +abstract class DockerJDBCIntegrationSuite + extends SparkFunSuite + with BeforeAndAfterAll + with Eventually + with SharedSQLContext { + + val db: DatabaseOnDocker + + private var docker: DockerClient = _ + private var containerId: String = _ + protected var jdbcUrl: String = _ + + override def beforeAll() { + super.beforeAll() + try { + docker = DefaultDockerClient.fromEnv.build() + // Check that Docker is actually up + try { + docker.ping() + } catch { + case NonFatal(e) => + log.error("Exception while connecting to Docker. Check whether Docker is running.") + throw e + } + // Ensure that the Docker image is installed: + try { + docker.inspectImage(db.imageName) + } catch { + case e: ImageNotFoundException => + log.warn(s"Docker image ${db.imageName} not found; pulling image from registry") + docker.pull(db.imageName) + } + // Configure networking (necessary for boot2docker / Docker Machine) + val externalPort: Int = { + val sock = new ServerSocket(0) + val port = sock.getLocalPort + sock.close() + port + } + val dockerIp = DockerUtils.getDockerIp() + val hostConfig: HostConfig = HostConfig.builder() + .networkMode("bridge") + .portBindings( + Map(s"${db.jdbcPort}/tcp" -> List(PortBinding.of(dockerIp, externalPort)).asJava).asJava) + .build() + // Create the database container: + val config = ContainerConfig.builder() + .image(db.imageName) + .networkDisabled(false) + .env(db.env.map { case (k, v) => s"$k=$v" }.toSeq.asJava) + .hostConfig(hostConfig) + .exposedPorts(s"${db.jdbcPort}/tcp") + .build() + containerId = docker.createContainer(config).id + // Start the container and wait until the database can accept JDBC connections: + docker.startContainer(containerId) + jdbcUrl = db.getJdbcUrl(dockerIp, externalPort) + eventually(timeout(60.seconds), interval(1.seconds)) { + val conn = java.sql.DriverManager.getConnection(jdbcUrl) + conn.close() + } + // Run any setup queries: + val conn: Connection = java.sql.DriverManager.getConnection(jdbcUrl) + try { + dataPreparation(conn) + } finally { + conn.close() + } + } catch { + case NonFatal(e) => + try { + afterAll() + } finally { + throw e + } + } + } + + override def afterAll() { + try { + if (docker != null) { + try { + if (containerId != null) { + docker.killContainer(containerId) + docker.removeContainer(containerId) + } + } catch { + case NonFatal(e) => + logWarning(s"Could not stop container $containerId", e) + } finally { + docker.close() + } + } + } finally { + super.afterAll() + } + } + + /** + * Prepare databases and tables for testing. + */ + def dataPreparation(connection: Connection): Unit +} diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala new file mode 100644 index 0000000000000..c68e4dc4933b1 --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.math.BigDecimal +import java.sql.{Connection, Date, Timestamp} +import java.util.Properties + +import org.apache.spark.tags.DockerTest + +@DockerTest +class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { + override val db = new DatabaseOnDocker { + override val imageName = "mysql:5.7.9" + override val env = Map( + "MYSQL_ROOT_PASSWORD" -> "rootpass" + ) + override val jdbcPort: Int = 3306 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass" + } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE DATABASE foo").executeUpdate() + conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y TEXT(8))").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate() + + conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), " + + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " + + "dbl DOUBLE)").executeUpdate() + conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', " + + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " + + "42.75, 1.0000000000000002)").executeUpdate() + + conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " + + "yr YEAR)").executeUpdate() + conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', " + + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate() + + // TODO: Test locale conversion for strings. + conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c TINYTEXT, " + + "d TEXT, e MEDIUMTEXT, f LONGTEXT, g BINARY(4), h VARBINARY(10), i BLOB)" + ).executeUpdate() + conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', " + + "'jumps', 'over', 'the', 'lazy', 'dog')").executeUpdate() + } + + test("Basic test") { + val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties) + val rows = df.collect() + assert(rows.length == 2) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 2) + assert(types(0).equals("class java.lang.Integer")) + assert(types(1).equals("class java.lang.String")) + } + + test("Numeric types") { + val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.Boolean")) + assert(types(1).equals("class java.lang.Long")) + assert(types(2).equals("class java.lang.Integer")) + assert(types(3).equals("class java.lang.Integer")) + assert(types(4).equals("class java.lang.Integer")) + assert(types(5).equals("class java.lang.Long")) + assert(types(6).equals("class java.math.BigDecimal")) + assert(types(7).equals("class java.lang.Double")) + assert(types(8).equals("class java.lang.Double")) + assert(rows(0).getBoolean(0) == false) + assert(rows(0).getLong(1) == 0x225) + assert(rows(0).getInt(2) == 17) + assert(rows(0).getInt(3) == 77777) + assert(rows(0).getInt(4) == 123456789) + assert(rows(0).getLong(5) == 123456789012345L) + val bd = new BigDecimal("123456789012345.12345678901234500000") + assert(rows(0).getAs[BigDecimal](6).equals(bd)) + assert(rows(0).getDouble(7) == 42.75) + assert(rows(0).getDouble(8) == 1.0000000000000002) + } + + test("Date types") { + val df = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 5) + assert(types(0).equals("class java.sql.Date")) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) + assert(types(3).equals("class java.sql.Timestamp")) + assert(types(4).equals("class java.sql.Date")) + assert(rows(0).getAs[Date](0).equals(Date.valueOf("1991-11-09"))) + assert(rows(0).getAs[Timestamp](1).equals(Timestamp.valueOf("1970-01-01 13:31:24"))) + assert(rows(0).getAs[Timestamp](2).equals(Timestamp.valueOf("1996-01-01 01:23:45"))) + assert(rows(0).getAs[Timestamp](3).equals(Timestamp.valueOf("2009-02-13 23:31:30"))) + assert(rows(0).getAs[Date](4).equals(Date.valueOf("2001-01-01"))) + } + + test("String types") { + val df = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.String")) + assert(types(1).equals("class java.lang.String")) + assert(types(2).equals("class java.lang.String")) + assert(types(3).equals("class java.lang.String")) + assert(types(4).equals("class java.lang.String")) + assert(types(5).equals("class java.lang.String")) + assert(types(6).equals("class [B")) + assert(types(7).equals("class [B")) + assert(types(8).equals("class [B")) + assert(rows(0).getString(0).equals("the")) + assert(rows(0).getString(1).equals("quick")) + assert(rows(0).getString(2).equals("brown")) + assert(rows(0).getString(3).equals("fox")) + assert(rows(0).getString(4).equals("jumps")) + assert(rows(0).getString(5).equals("over")) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103))) + } + + test("Basic write test") { + val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) + df2.write.jdbc(jdbcUrl, "datescopy", new Properties) + df3.write.jdbc(jdbcUrl, "stringscopy", new Properties) + } +} diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala new file mode 100644 index 0000000000000..164a7f396280c --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc + +import java.sql.Connection +import java.util.Properties + +import org.apache.spark.tags.DockerTest + +@DockerTest +class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { + override val db = new DatabaseOnDocker { + override val imageName = "postgres:9.4.5" + override val env = Map( + "POSTGRES_PASSWORD" -> "rootpass" + ) + override val jdbcPort = 5432 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" + } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE DATABASE foo").executeUpdate() + conn.setCatalog("foo") + conn.prepareStatement("CREATE TABLE bar (a text, b integer, c double precision, d bigint, " + + "e bit(1), f bit(10), g bytea, h boolean, i inet, j cidr)").executeUpdate() + conn.prepareStatement("INSERT INTO bar VALUES ('hello', 42, 1.25, 123456789012345, B'0', " + + "B'1000100101', E'\\\\xDEADBEEF', true, '172.16.0.42', '192.168.0.0/16')").executeUpdate() + } + + test("Type mapping for various types") { + val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 10) + assert(types(0).equals("class java.lang.String")) + assert(types(1).equals("class java.lang.Integer")) + assert(types(2).equals("class java.lang.Double")) + assert(types(3).equals("class java.lang.Long")) + assert(types(4).equals("class java.lang.Boolean")) + assert(types(5).equals("class [B")) + assert(types(6).equals("class [B")) + assert(types(7).equals("class java.lang.Boolean")) + assert(types(8).equals("class java.lang.String")) + assert(types(9).equals("class java.lang.String")) + assert(rows(0).getString(0).equals("hello")) + assert(rows(0).getInt(1) == 42) + assert(rows(0).getDouble(2) == 1.25) + assert(rows(0).getLong(3) == 123456789012345L) + assert(rows(0).getBoolean(4) == false) + // BIT(10)'s come back as ASCII strings of ten ASCII 0's and 1's... + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](5), + Array[Byte](49, 48, 48, 48, 49, 48, 48, 49, 48, 49))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), + Array[Byte](0xDE.toByte, 0xAD.toByte, 0xBE.toByte, 0xEF.toByte))) + assert(rows(0).getBoolean(7) == true) + assert(rows(0).getString(8) == "172.16.0.42") + assert(rows(0).getString(9) == "192.168.0.0/16") + } + + test("Basic write test") { + val df = sqlContext.read.jdbc(jdbcUrl, "bar", new Properties) + df.write.jdbc(jdbcUrl, "public.barcopy", new Properties) + // Test only that it doesn't crash. + } +} diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala b/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala new file mode 100644 index 0000000000000..87271776d8564 --- /dev/null +++ b/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.net.{Inet4Address, NetworkInterface, InetAddress} + +import scala.collection.JavaConverters._ +import scala.sys.process._ +import scala.util.Try + +private[spark] object DockerUtils { + + def getDockerIp(): String = { + /** If docker-machine is setup on this box, attempts to find the ip from it. */ + def findFromDockerMachine(): Option[String] = { + sys.env.get("DOCKER_MACHINE_NAME").flatMap { name => + Try(Seq("/bin/bash", "-c", s"docker-machine ip $name 2>/dev/null").!!.trim).toOption + } + } + sys.env.get("DOCKER_IP") + .orElse(findFromDockerMachine()) + .orElse(Try(Seq("/bin/bash", "-c", "boot2docker ip 2>/dev/null").!!.trim).toOption) + .getOrElse { + // This block of code is based on Utils.findLocalInetAddress(), but is modified to blacklist + // certain interfaces. + val address = InetAddress.getLocalHost + // Address resolves to something like 127.0.1.1, which happens on Debian; try to find + // a better address using the local network interfaces + // getNetworkInterfaces returns ifs in reverse order compared to ifconfig output order + // on unix-like system. On windows, it returns in index order. + // It's more proper to pick ip address following system output order. + val blackListedIFs = Seq( + "vboxnet0", // Mac + "docker0" // Linux + ) + val activeNetworkIFs = NetworkInterface.getNetworkInterfaces.asScala.toSeq.filter { i => + !blackListedIFs.contains(i.getName) + } + val reOrderedNetworkIFs = activeNetworkIFs.reverse + for (ni <- reOrderedNetworkIFs) { + val addresses = ni.getInetAddresses.asScala + .filterNot(addr => addr.isLinkLocalAddress || addr.isLoopbackAddress).toSeq + if (addresses.nonEmpty) { + val addr = addresses.find(_.isInstanceOf[Inet4Address]).getOrElse(addresses.head) + // because of Inet6Address.toHostName may add interface at the end if it knows about it + val strippedAddress = InetAddress.getByAddress(addr.getAddress) + return strippedAddress.getHostAddress + } + } + address.getHostAddress + } + } +} diff --git a/pom.xml b/pom.xml index fd8c773513881..c499a80aa0f43 100644 --- a/pom.xml +++ b/pom.xml @@ -98,6 +98,7 @@ sql/catalyst sql/core sql/hive + docker-integration-tests unsafe assembly external/twitter @@ -778,6 +779,19 @@ 0.11 test + + com.spotify + docker-client + shaded + 3.2.1 + test + + + guava + com.google.guava + + + org.apache.curator curator-recipes diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a9fb741d75933..b7c619224329f 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -43,8 +43,9 @@ object BuildCommons { "streaming-zeromq", "launcher", "unsafe", "test-tags").map(ProjectRef(buildLocation, _)) val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl, - streamingKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", - "streaming-kinesis-asl").map(ProjectRef(buildLocation, _)) + streamingKinesisAsl, dockerIntegrationTests) = + Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl", "streaming-kinesis-asl", + "docker-integration-tests").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingMqttAssembly, streamingKinesisAslAssembly) = Seq("assembly", "examples", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-mqtt-assembly", "streaming-kinesis-asl-assembly") @@ -240,6 +241,8 @@ object SparkBuild extends PomBuild { enable(Flume.settings)(streamingFlumeSink) + enable(DockerIntegrationTests.settings)(dockerIntegrationTests) + /** * Adds the ability to run the spark shell directly from SBT without building an assembly @@ -291,6 +294,13 @@ object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } +object DockerIntegrationTests { + // This serves to override the override specified in DependencyOverrides: + lazy val settings = Seq( + dependencyOverrides += "com.google.guava" % "guava" % "18.0" + ) +} + /** * Overrides to work around sbt's dependency resolution being different from Maven's. */ diff --git a/tags/src/main/java/org/apache/spark/tags/DockerTest.java b/tags/src/main/java/org/apache/spark/tags/DockerTest.java new file mode 100644 index 0000000000000..0fecf3b8f979a --- /dev/null +++ b/tags/src/main/java/org/apache/spark/tags/DockerTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.tags; + +import java.lang.annotation.*; +import org.scalatest.TagAnnotation; + +@TagAnnotation +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.METHOD, ElementType.TYPE}) +public @interface DockerTest { }