From d11d5b95ef82a208c579daa0073bdc072a682be5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Fri, 1 May 2015 23:50:12 +0800 Subject: [PATCH 01/10] [SPARK-7294] ADD BETWEEN --- python/pyspark/sql/dataframe.py | 12 ++++++++++++ python/pyspark/sql/tests.py | 6 ++++++ .../main/scala/org/apache/spark/sql/Column.scala | 14 ++++++++++++++ .../apache/spark/sql/ColumnExpressionSuite.scala | 6 ++++++ .../test/scala/org/apache/spark/sql/TestData.scala | 11 +++++++++++ 5 files changed, 49 insertions(+) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5908ebc990a56..a4cbc7396e386 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1289,6 +1289,18 @@ def cast(self, dataType): raise TypeError("unexpected type: %s" % type(dataType)) return Column(jc) + @ignore_unicode_prefix + def between(self, col1, col2): + """ A boolean expression that is evaluated to true if the value of this + expression is between the given columns. + + >>> df[df.col1.between(col2, col3)].collect() + [Row(col1=5, col2=6, col3=8)] + """ + #sc = SparkContext._active_spark_context + jc = self > col1 & self < col2 + return Column(jc) + def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5640bb5ea2346..206e3b7fd08f2 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -426,6 +426,12 @@ def test_rand_functions(self): for row in rndn: assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] + def test_between_function(self): + df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=3)]).toDF() + self.assertEqual([False, True, False], + df.select(df.a.between(df.b, df.c)).collect()) + + def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 33f9d0b37d006..8e0eab7918131 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -295,6 +295,20 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def eqNullSafe(other: Any): Column = this <=> other + /** + * Between col1 and col2. + * + * @group java_expr_ops + */ + def between(col1: String, col2: String): Column = between(Column(col1), Column(col2)) + + /** + * Between col1 and col2. + * + * @group java_expr_ops + */ + def between(col1: Column, col2: Column): Column = And(GreaterThan(this.expr, col1.expr), LessThan(this.expr, col2.expr)) + /** * True if the current expression is null. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 6322faf4d9907..0a81f884e9a16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -208,6 +208,12 @@ class ColumnExpressionSuite extends QueryTest { testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1))) } + test("between") { + checkAnswer( + testData4.filter($"a".between($"b", $"c")), + testData4.collect().toSeq.filter(r => r.getInt(0) > r.getInt(1) && r.getInt(0) < r.getInt(2))) + } + val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 225b51bd73d6c..487d07249922f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -57,6 +57,17 @@ object TestData { TestData2(3, 2) :: Nil, 2).toDF() testData2.registerTempTable("testData2") + case class TestData4(a: Int, b: Int, c: Int) + val testData4 = + TestSQLContext.sparkContext.parallelize( + TestData4(0, 1, 2) :: + TestData4(1, 2, 3) :: + TestData4(2, 1, 0) :: + TestData4(2, 2, 4) :: + TestData4(3, 1, 6) :: + TestData4(3, 2, 0) :: Nil, 2).toDF() + testData4.registerTempTable("TestData4") + case class DecimalData(a: BigDecimal, b: BigDecimal) val decimalData = From baf839b4a4aa8d7d4ab8cdb1a5b82affd3ce376e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Sat, 2 May 2015 09:39:17 +0800 Subject: [PATCH 02/10] [SPARK-7294] ADD BETWEEN --- python/pyspark/sql/dataframe.py | 7 +++---- python/pyspark/sql/tests.py | 4 ++-- .../main/scala/org/apache/spark/sql/Column.scala | 13 +++++++++---- .../apache/spark/sql/ColumnExpressionSuite.scala | 2 +- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index a4cbc7396e386..8c09bf23f3cc0 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1290,15 +1290,14 @@ def cast(self, dataType): return Column(jc) @ignore_unicode_prefix - def between(self, col1, col2): + def between(self, lowerBound, upperBound): """ A boolean expression that is evaluated to true if the value of this expression is between the given columns. - >>> df[df.col1.between(col2, col3)].collect() + >>> df[df.col1.between(lowerBound, upperBound)].collect() [Row(col1=5, col2=6, col3=8)] """ - #sc = SparkContext._active_spark_context - jc = self > col1 & self < col2 + jc = (self >= lowerBound) & (self <= upperBound) return Column(jc) def __repr__(self): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 206e3b7fd08f2..b5faedfe15e46 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -427,8 +427,8 @@ def test_rand_functions(self): assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] def test_between_function(self): - df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=3)]).toDF() - self.assertEqual([False, True, False], + df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]).toDF() + self.assertEqual([False, True, True], df.select(df.a.between(df.b, df.c)).collect()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 8e0eab7918131..b51b6368eeb56 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -296,18 +296,23 @@ class Column(protected[sql] val expr: Expression) extends Logging { def eqNullSafe(other: Any): Column = this <=> other /** - * Between col1 and col2. + * True if the current column is between the lower bound and upper bound, inclusive. * * @group java_expr_ops */ - def between(col1: String, col2: String): Column = between(Column(col1), Column(col2)) + def between(lowerBound: String, upperBound: String): Column = { + between(Column(lowerBound), Column(upperBound)) + } /** - * Between col1 and col2. + * True if the current column is between the lower bound and upper bound, inclusive. * * @group java_expr_ops */ - def between(col1: Column, col2: Column): Column = And(GreaterThan(this.expr, col1.expr), LessThan(this.expr, col2.expr)) + def between(lowerBound: Column, upperBound: Column): Column = { + And(GreaterThanOrEqual(this.expr, lowerBound.expr), + LessThanOrEqual(this.expr, upperBound.expr)) + } /** * True if the current expression is null. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 0a81f884e9a16..b63c1814adc3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -211,7 +211,7 @@ class ColumnExpressionSuite extends QueryTest { test("between") { checkAnswer( testData4.filter($"a".between($"b", $"c")), - testData4.collect().toSeq.filter(r => r.getInt(0) > r.getInt(1) && r.getInt(0) < r.getInt(2))) + testData4.collect().toSeq.filter(r => r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2))) } val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( From 7d623680d2c726a53b9e36c78f654e34c40f3dba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Sat, 2 May 2015 14:17:10 +0800 Subject: [PATCH 03/10] [SPARK-7294] ADD BETWEEN --- python/pyspark/sql/dataframe.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8c09bf23f3cc0..2538bd139bb3f 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1297,8 +1297,7 @@ def between(self, lowerBound, upperBound): >>> df[df.col1.between(lowerBound, upperBound)].collect() [Row(col1=5, col2=6, col3=8)] """ - jc = (self >= lowerBound) & (self <= upperBound) - return Column(jc) + return (self >= lowerBound) & (self <= upperBound) def __repr__(self): return 'Column<%s>' % self._jc.toString().encode('utf8') From f080f8d118f00e4f27936d55e74d391bac690c33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Sat, 2 May 2015 22:00:12 +0800 Subject: [PATCH 04/10] update pep8 --- python/pyspark/sql/tests.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index edf9f95a8ce65..000dab99ea730 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -439,7 +439,9 @@ def test_rand_functions(self): assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] def test_between_function(self): - df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]).toDF() + df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), + Row(a=2, b=1, c=3), + Row(a=4, b=1, c=4)]).toDF() self.assertEqual([False, True, True], df.select(df.a.between(df.b, df.c)).collect()) From 7b9b8583b25f3417e9b6c5672598d325678c7769 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Sun, 3 May 2015 14:04:08 +0800 Subject: [PATCH 05/10] undo --- .../main/scala/org/apache/spark/sql/Column.scala | 3 +-- .../apache/spark/sql/ColumnExpressionSuite.scala | 14 +++++++++++--- .../test/scala/org/apache/spark/sql/TestData.scala | 11 ----------- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b51b6368eeb56..590c9c2db97a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -310,8 +310,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * @group java_expr_ops */ def between(lowerBound: Column, upperBound: Column): Column = { - And(GreaterThanOrEqual(this.expr, lowerBound.expr), - LessThanOrEqual(this.expr, upperBound.expr)) + (this >= lowerBound) && (this <= upperBound) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b63c1814adc3d..dcea32f97c840 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -209,9 +209,17 @@ class ColumnExpressionSuite extends QueryTest { } test("between") { - checkAnswer( - testData4.filter($"a".between($"b", $"c")), - testData4.collect().toSeq.filter(r => r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2))) + val testData = TestSQLContext.sparkContext.parallelize( + (0, 1, 2) :: + (1, 2, 3) :: + (2, 1, 0) :: + (2, 2, 4) :: + (3, 1, 6) :: + (3, 2, 0) :: Nil).toDF("a", "b", "c") + testData.registerTempTable("TestData4") + checkAnswer( + testData.filter($"a".between($"b", $"c")), + testData.collect().toSeq.filter(r => r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2))) } val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 487d07249922f..225b51bd73d6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -57,17 +57,6 @@ object TestData { TestData2(3, 2) :: Nil, 2).toDF() testData2.registerTempTable("testData2") - case class TestData4(a: Int, b: Int, c: Int) - val testData4 = - TestSQLContext.sparkContext.parallelize( - TestData4(0, 1, 2) :: - TestData4(1, 2, 3) :: - TestData4(2, 1, 0) :: - TestData4(2, 2, 4) :: - TestData4(3, 1, 6) :: - TestData4(3, 2, 0) :: Nil, 2).toDF() - testData4.registerTempTable("TestData4") - case class DecimalData(a: BigDecimal, b: BigDecimal) val decimalData = From 7e64d1ecad965f56d5cb6887f64fd8cfd0b8263b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=91=E5=B3=A4?= Date: Mon, 4 May 2015 14:19:27 +0800 Subject: [PATCH 06/10] Update --- python/pyspark/sql/tests.py | 1 - .../scala/org/apache/spark/sql/Column.scala | 11 +---------- .../spark/sql/ColumnExpressionSuite.scala | 18 +++++++++--------- 3 files changed, 10 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 000dab99ea730..074eae5cb2b16 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -445,7 +445,6 @@ def test_between_function(self): self.assertEqual([False, True, True], df.select(df.a.between(df.b, df.c)).collect()) - def test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 590c9c2db97a9..c0503bf047052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -300,16 +300,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { * * @group java_expr_ops */ - def between(lowerBound: String, upperBound: String): Column = { - between(Column(lowerBound), Column(upperBound)) - } - - /** - * True if the current column is between the lower bound and upper bound, inclusive. - * - * @group java_expr_ops - */ - def between(lowerBound: Column, upperBound: Column): Column = { + def between(lowerBound: Any, upperBound: Any): Column = { (this >= lowerBound) && (this <= upperBound) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index dcea32f97c840..3c1ad656fc855 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -211,15 +211,15 @@ class ColumnExpressionSuite extends QueryTest { test("between") { val testData = TestSQLContext.sparkContext.parallelize( (0, 1, 2) :: - (1, 2, 3) :: - (2, 1, 0) :: - (2, 2, 4) :: - (3, 1, 6) :: - (3, 2, 0) :: Nil).toDF("a", "b", "c") - testData.registerTempTable("TestData4") - checkAnswer( - testData.filter($"a".between($"b", $"c")), - testData.collect().toSeq.filter(r => r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2))) + (1, 2, 3) :: + (2, 1, 0) :: + (2, 2, 4) :: + (3, 1, 6) :: + (3, 2, 0) :: Nil).toDF("a", "b", "c") + val expectAnswer = testData.collect().toSeq. + filter(r => r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2)) + + checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer) } val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( From c54d90440ac911bc5bfb6439f53229232e403317 Mon Sep 17 00:00:00 2001 From: kaka1992 Date: Tue, 5 May 2015 10:41:38 +0800 Subject: [PATCH 07/10] Fix empty map bug. --- python/pyspark/sql/dataframe.py | 3 --- python/pyspark/sql/tests.py | 7 ++++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5d1e7b630bf3a..66bd851af2b21 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1334,9 +1334,6 @@ def cast(self, dataType): def between(self, lowerBound, upperBound): """ A boolean expression that is evaluated to true if the value of this expression is between the given columns. - - >>> df[df.col1.between(lowerBound, upperBound)].collect() - [Row(col1=5, col2=6, col3=8)] """ return (self >= lowerBound) & (self <= upperBound) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 074eae5cb2b16..c48dba7ce4054 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -439,9 +439,10 @@ def test_rand_functions(self): assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] def test_between_function(self): - df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), - Row(a=2, b=1, c=3), - Row(a=4, b=1, c=4)]).toDF() + df = self.sqlCtx.parallelize([ + Row(a=1, b=2, c=3), + Row(a=2, b=1, c=3), + Row(a=4, b=1, c=4)]).toDF() self.assertEqual([False, True, True], df.select(df.a.between(df.b, df.c)).collect()) From d2e7f722bbe32ec1b5e0adce2f749feb67102926 Mon Sep 17 00:00:00 2001 From: kaka1992 Date: Tue, 5 May 2015 11:11:52 +0800 Subject: [PATCH 08/10] Fix python style in sql/test. --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c48dba7ce4054..846d3eea3e333 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -444,7 +444,7 @@ def test_between_function(self): Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]).toDF() self.assertEqual([False, True, True], - df.select(df.a.between(df.b, df.c)).collect()) + df.select(df.a.between(df.b, df.c)).collect()) def test_save_and_load(self): df = self.df From f92881631a8b73730f2acb08965ba66652c1d11e Mon Sep 17 00:00:00 2001 From: kaka1992 Date: Tue, 5 May 2015 15:15:34 +0800 Subject: [PATCH 09/10] Fix python style in sql/test. --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 846d3eea3e333..8aef22af3d445 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -439,7 +439,7 @@ def test_rand_functions(self): assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1] def test_between_function(self): - df = self.sqlCtx.parallelize([ + df = self.sc.parallelize([ Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]).toDF() From b15360d4c6a56336ab9c8b0ca8b3a8467fe9a82e Mon Sep 17 00:00:00 2001 From: kaka1992 Date: Tue, 5 May 2015 17:06:20 +0800 Subject: [PATCH 10/10] Fix python unit test in sql/test. =_= I forget to commit this file last time. --- python/pyspark/sql/tests.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 8aef22af3d445..77a8bddea9985 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -443,8 +443,8 @@ def test_between_function(self): Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]).toDF() - self.assertEqual([False, True, True], - df.select(df.a.between(df.b, df.c)).collect()) + self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], + df.filter(df.a.between(df.b, df.c)).collect()) def test_save_and_load(self): df = self.df