[SPARK-7294] ADD BETWEEN
云峤 committed May 2, 2015
1 parent d11d5b9 commit baf839b
Showing 4 changed files with 15 additions and 11 deletions.
7 changes: 3 additions & 4 deletions python/pyspark/sql/dataframe.py
@@ -1290,15 +1290,14 @@ def cast(self, dataType):
         return Column(jc)
 
     @ignore_unicode_prefix
-    def between(self, col1, col2):
+    def between(self, lowerBound, upperBound):
         """ A boolean expression that is evaluated to true if the value of this
         expression is between the given columns.
 
-        >>> df[df.col1.between(col2, col3)].collect()
+        >>> df[df.col1.between(lowerBound, upperBound)].collect()
         [Row(col1=5, col2=6, col3=8)]
         """
-        #sc = SparkContext._active_spark_context
-        jc = self > col1 & self < col2
+        jc = (self >= lowerBound) & (self <= upperBound)
         return Column(jc)
 
     def __repr__(self):
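For orientation, a minimal usage sketch of the new inclusive semantics (the session setup is hypothetical; it uses the 1.x-era SQLContext API that the tests in this commit rely on):

    from pyspark import SparkContext
    from pyspark.sql import Row, SQLContext

    sc = SparkContext(appName="between-demo")  # hypothetical app name
    sqlCtx = SQLContext(sc)
    df = sqlCtx.createDataFrame([Row(age=1), Row(age=2), Row(age=4)])

    # between() now desugars to (col >= lower) & (col <= upper), so both
    # endpoints match: age=2 hits the lower bound, age=4 the upper bound.
    print(df.select(df.age.between(2, 4)).collect())
    # -> one boolean per row: [False, True, True]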
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests.py
@@ -427,8 +427,8 @@ def test_rand_functions(self):
             assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
 
     def test_between_function(self):
-        df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=3)]).toDF()
-        self.assertEqual([False, True, False],
+        df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]).toDF()
+        self.assertEqual([False, True, True],
             df.select(df.a.between(df.b, df.c)).collect())
 
 
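The new expectation can be sanity-checked by hand; a plain-Python sketch of the same comparisons (no Spark needed):

    # Plain-Python check of the updated fixture.
    rows = [dict(a=1, b=2, c=3), dict(a=2, b=1, c=3), dict(a=4, b=1, c=4)]
    print([r["b"] <= r["a"] <= r["c"] for r in rows])  # [False, True, True]
    # The third row passes only because a == c and the bounds are now
    # inclusive; under the old strict comparison it evaluated to False.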
13 changes: 9 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -296,18 +296,23 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   def eqNullSafe(other: Any): Column = this <=> other
 
   /**
-   * Between col1 and col2.
+   * True if the current column is between the lower bound and upper bound, inclusive.
    *
    * @group java_expr_ops
    */
-  def between(col1: String, col2: String): Column = between(Column(col1), Column(col2))
+  def between(lowerBound: String, upperBound: String): Column = {
+    between(Column(lowerBound), Column(upperBound))
+  }
 
   /**
-   * Between col1 and col2.
+   * True if the current column is between the lower bound and upper bound, inclusive.
    *
    * @group java_expr_ops
    */
-  def between(col1: Column, col2: Column): Column = And(GreaterThan(this.expr, col1.expr), LessThan(this.expr, col2.expr))
+  def between(lowerBound: Column, upperBound: Column): Column = {
+    And(GreaterThanOrEqual(this.expr, lowerBound.expr),
+      LessThanOrEqual(this.expr, upperBound.expr))
+  }
 
   /**
    * True if the current expression is null.
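Because the Scala side now combines GreaterThanOrEqual and LessThanOrEqual under And, between should agree with an explicitly spelled-out filter; a small equivalence sketch in PySpark (hypothetical data, same 1.x-era API as above):

    from pyspark import SparkContext
    from pyspark.sql import Row, SQLContext

    sc = SparkContext(appName="between-equivalence")  # hypothetical app name
    sqlCtx = SQLContext(sc)
    df = sqlCtx.createDataFrame(
        [Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)])

    # between(b, c) and the explicit conjunction build the same predicate
    # tree, so the two filters must return the same rows.
    assert (df.filter(df.a.between(df.b, df.c)).collect() ==
            df.filter((df.a >= df.b) & (df.a <= df.c)).collect())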
2 changes: 1 addition & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -211,7 +211,7 @@ class ColumnExpressionSuite extends QueryTest {
test("between") {
checkAnswer(
testData4.filter($"a".between($"b", $"c")),
testData4.collect().toSeq.filter(r => r.getInt(0) > r.getInt(1) && r.getInt(0) < r.getInt(2)))
testData4.collect().toSeq.filter(r => r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2)))
}

val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
