[SPARK-36813][SQL][PYTHON] Propose an infrastructure of as-of join and implement ps.merge_asof

### What changes were proposed in this pull request?

Proposes an infrastructure for as-of joins and implements `ps.merge_asof`.

1. Introduce an `AsOfJoin` logical plan.
2. Rewrite the plan in the optimization phase:

- From something like this (the SQL syntax is not determined yet):

```sql
SELECT * FROM left ASOF JOIN right ON (condition, as_of on(left.t, right.t), tolerance)
```

- To something like:

```sql
SELECT left.*, __right__.*
FROM (
     SELECT
          left.*,
          (
               SELECT MIN_BY(STRUCT(right.*), left.t - right.t) AS __nearest_right__
               FROM right
               WHERE condition AND left.t >= right.t AND right.t >= left.t - tolerance
          ) AS __right__
     FROM left
     )
WHERE __right__ IS NOT NULL
```

3. The rewritten scalar subquery is then handled by the existing decorrelation framework.

Note: APIs on SQL DataFrames and the SQL syntax are TBD (e.g., [SPARK-22947](https://issues.apache.org/jira/browse/SPARK-22947)), although temporary APIs are added here.
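
As a point of reference, the temporary DataFrame API added in this PR (`_joinAsOf`, see the `python/pyspark/sql/dataframe.py` diff below) is what builds the `AsOfJoin` plan that the rule above rewrites. A minimal sketch, assuming `left` and `right` are DataFrames that each have a time column `t`:

```py
from pyspark.sql import functions as F

# `left` and `right` are assumed to already exist; the call below produces an
# AsOfJoin logical plan, which the optimizer rewrites into the correlated
# scalar-subquery form shown above.
joined = left._joinAsOf(
    right,
    leftAsOfColumn="t",
    rightAsOfColumn="t",
    how="inner",                                  # or "left"
    tolerance=F.expr("INTERVAL 2 MILLISECONDS"),  # optional; must be a constant Column
    allowExactMatches=True,
    direction="backward",                         # "backward", "forward", or "nearest"
)
```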

### Why are the changes needed?

Pandas' `merge_asof`, i.e., an as-of join for SQL/DataFrames, is useful for time series analysis.

### Does this PR introduce _any_ user-facing change?

Yes. `ps.merge_asof` can be used.

```py
>>> quotes
                     time ticker     bid     ask
0 2016-05-25 13:30:00.023   GOOG  720.50  720.93
1 2016-05-25 13:30:00.023   MSFT   51.95   51.96
2 2016-05-25 13:30:00.030   MSFT   51.97   51.98
3 2016-05-25 13:30:00.041   MSFT   51.99   52.00
4 2016-05-25 13:30:00.048   GOOG  720.50  720.93
5 2016-05-25 13:30:00.049   AAPL   97.99   98.01
6 2016-05-25 13:30:00.072   GOOG  720.50  720.88
7 2016-05-25 13:30:00.075   MSFT   52.01   52.03

>>> trades
                     time ticker   price  quantity
0 2016-05-25 13:30:00.023   MSFT   51.95        75
1 2016-05-25 13:30:00.038   MSFT   51.95       155
2 2016-05-25 13:30:00.048   GOOG  720.77       100
3 2016-05-25 13:30:00.048   GOOG  720.92       100
4 2016-05-25 13:30:00.048   AAPL   98.00       100

>>> ps.merge_asof(
...    trades, quotes, on="time", by="ticker"
... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True)
                     time ticker   price  quantity     bid     ask
0 2016-05-25 13:30:00.023   MSFT   51.95        75   51.95   51.96
1 2016-05-25 13:30:00.038   MSFT   51.95       155   51.97   51.98
2 2016-05-25 13:30:00.048   AAPL   98.00       100     NaN     NaN
3 2016-05-25 13:30:00.048   GOOG  720.77       100  720.50  720.93
4 2016-05-25 13:30:00.048   GOOG  720.92       100  720.50  720.93

>>> ps.merge_asof(
...     trades,
...     quotes,
...     on="time",
...     by="ticker",
...     tolerance=F.expr("INTERVAL 2 MILLISECONDS")  # pd.Timedelta("2ms")
... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True)
                     time ticker   price  quantity     bid     ask
0 2016-05-25 13:30:00.023   MSFT   51.95        75   51.95   51.96
1 2016-05-25 13:30:00.038   MSFT   51.95       155     NaN     NaN
2 2016-05-25 13:30:00.048   AAPL   98.00       100     NaN     NaN
3 2016-05-25 13:30:00.048   GOOG  720.77       100  720.50  720.93
4 2016-05-25 13:30:00.048   GOOG  720.92       100  720.50  720.93
```

Note: Since an `IntervalType` literal is not supported yet, the interval value has to be specified via `F.expr` as a workaround.
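
For completeness, a sketch that reconstructs the `quotes` and `trades` frames shown above (the values are copied from the printed output, and `ps.from_pandas` is used in the same way as in the tests below; `F` refers to `pyspark.sql.functions`):

```py
import pandas as pd
from pyspark import pandas as ps
from pyspark.sql import functions as F  # used for the `tolerance` argument above

# Quote and trade data copied from the printed frames above.
quotes = ps.from_pandas(pd.DataFrame({
    "time": pd.to_datetime([
        "2016-05-25 13:30:00.023", "2016-05-25 13:30:00.023",
        "2016-05-25 13:30:00.030", "2016-05-25 13:30:00.041",
        "2016-05-25 13:30:00.048", "2016-05-25 13:30:00.049",
        "2016-05-25 13:30:00.072", "2016-05-25 13:30:00.075",
    ]),
    "ticker": ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"],
    "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01],
    "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03],
}))

trades = ps.from_pandas(pd.DataFrame({
    "time": pd.to_datetime([
        "2016-05-25 13:30:00.023", "2016-05-25 13:30:00.038",
        "2016-05-25 13:30:00.048", "2016-05-25 13:30:00.048",
        "2016-05-25 13:30:00.048",
    ]),
    "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"],
    "price": [51.95, 51.95, 720.77, 720.92, 98.00],
    "quantity": [75, 155, 100, 100, 100],
}))
```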

### How was this patch tested?

Added tests.

Closes #34053 from ueshin/issues/SPARK-36813/merge_asof.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
ueshin authored and HyukjinKwon committed Sep 29, 2021
1 parent a9b4c27 commit 05c0fa5
Showing 14 changed files with 1,560 additions and 30 deletions.
497 changes: 497 additions & 0 deletions python/pyspark/pandas/namespace.py

Large diffs are not rendered by default.

138 changes: 138 additions & 0 deletions python/pyspark/pandas/tests/test_reshape.py
@@ -24,6 +24,7 @@

from pyspark import pandas as ps
from pyspark.pandas.utils import name_like_string
from pyspark.sql.utils import AnalysisException
from pyspark.testing.pandasutils import PandasOnSparkTestCase


@@ -283,6 +284,143 @@ def test_get_dummies_multiindex_columns(self):
pd.get_dummies(pdf, columns=("x", 1), dtype=np.int8).rename(columns=name_like_string),
)

def test_merge_asof(self):
pdf_left = pd.DataFrame(
{"a": [1, 5, 10], "b": ["x", "y", "z"], "left_val": ["a", "b", "c"]}, index=[10, 20, 30]
)
pdf_right = pd.DataFrame(
{"a": [1, 2, 3, 6, 7], "b": ["v", "w", "x", "y", "z"], "right_val": [1, 2, 3, 6, 7]},
index=[100, 101, 102, 103, 104],
)
psdf_left = ps.from_pandas(pdf_left)
psdf_right = ps.from_pandas(pdf_right)

self.assert_eq(
pd.merge_asof(pdf_left, pdf_right, on="a").sort_values("a").reset_index(drop=True),
ps.merge_asof(psdf_left, psdf_right, on="a").sort_values("a").reset_index(drop=True),
)
self.assert_eq(
(
pd.merge_asof(pdf_left, pdf_right, left_on="a", right_on="a")
.sort_values("a")
.reset_index(drop=True)
),
(
ps.merge_asof(psdf_left, psdf_right, left_on="a", right_on="a")
.sort_values("a")
.reset_index(drop=True)
),
)
if LooseVersion(pd.__version__) >= LooseVersion("1.3"):
self.assert_eq(
pd.merge_asof(
pdf_left.set_index("a"), pdf_right, left_index=True, right_on="a"
).sort_index(),
ps.merge_asof(
psdf_left.set_index("a"), psdf_right, left_index=True, right_on="a"
).sort_index(),
)
else:
expected = pd.DataFrame(
{
"b_x": ["x", "y", "z"],
"left_val": ["a", "b", "c"],
"a": [1, 3, 7],
"b_y": ["v", "x", "z"],
"right_val": [1, 3, 7],
},
index=pd.Index([1, 5, 10], name="a"),
)
self.assert_eq(
expected,
ps.merge_asof(
psdf_left.set_index("a"), psdf_right, left_index=True, right_on="a"
).sort_index(),
)
self.assert_eq(
pd.merge_asof(
pdf_left, pdf_right.set_index("a"), left_on="a", right_index=True
).sort_index(),
ps.merge_asof(
psdf_left, psdf_right.set_index("a"), left_on="a", right_index=True
).sort_index(),
)
self.assert_eq(
pd.merge_asof(
pdf_left.set_index("a"), pdf_right.set_index("a"), left_index=True, right_index=True
).sort_index(),
ps.merge_asof(
psdf_left.set_index("a"),
psdf_right.set_index("a"),
left_index=True,
right_index=True,
).sort_index(),
)
self.assert_eq(
(
pd.merge_asof(pdf_left, pdf_right, on="a", by="b")
.sort_values("a")
.reset_index(drop=True)
),
(
ps.merge_asof(psdf_left, psdf_right, on="a", by="b")
.sort_values("a")
.reset_index(drop=True)
),
)
self.assert_eq(
(
pd.merge_asof(pdf_left, pdf_right, on="a", tolerance=1)
.sort_values("a")
.reset_index(drop=True)
),
(
ps.merge_asof(psdf_left, psdf_right, on="a", tolerance=1)
.sort_values("a")
.reset_index(drop=True)
),
)
self.assert_eq(
(
pd.merge_asof(pdf_left, pdf_right, on="a", allow_exact_matches=False)
.sort_values("a")
.reset_index(drop=True)
),
(
ps.merge_asof(psdf_left, psdf_right, on="a", allow_exact_matches=False)
.sort_values("a")
.reset_index(drop=True)
),
)
self.assert_eq(
(
pd.merge_asof(pdf_left, pdf_right, on="a", direction="forward")
.sort_values("a")
.reset_index(drop=True)
),
(
ps.merge_asof(psdf_left, psdf_right, on="a", direction="forward")
.sort_values("a")
.reset_index(drop=True)
),
)
self.assert_eq(
(
pd.merge_asof(pdf_left, pdf_right, on="a", direction="nearest")
.sort_values("a")
.reset_index(drop=True)
),
(
ps.merge_asof(psdf_left, psdf_right, on="a", direction="nearest")
.sort_values("a")
.reset_index(drop=True)
),
)

self.assertRaises(
AnalysisException, lambda: ps.merge_asof(psdf_left, psdf_right, on="a", tolerance=-1)
)


if __name__ == "__main__":
import unittest
117 changes: 117 additions & 0 deletions python/pyspark/sql/dataframe.py
@@ -1357,6 +1357,123 @@ def join(self, other, on=None, how=None):
jdf = self._jdf.join(other._jdf, on, how)
return DataFrame(jdf, self.sql_ctx)

# TODO(SPARK-22947): Fix the DataFrame API.
def _joinAsOf(
self,
other,
leftAsOfColumn,
rightAsOfColumn,
on=None,
how=None,
*,
tolerance=None,
allowExactMatches=True,
direction="backward",
):
"""
Perform an as-of join.
This is similar to a left-join except that we match on nearest
key rather than equal keys.
.. versionadded:: 3.3.0
Parameters
----------
other : :class:`DataFrame`
Right side of the join
leftAsOfColumn : str or :class:`Column`
a string for the as-of join column name, or a Column
rightAsOfColumn : str or :class:`Column`
a string for the as-of join column name, or a Column
on : str, list or :class:`Column`, optional
a string for the join column name, a list of column names,
a join expression (Column), or a list of Columns.
If `on` is a string or a list of strings indicating the name of the join column(s),
the column(s) must exist on both sides, and this performs an equi-join.
how : str, optional
default ``inner``. Must be one of: ``inner`` and ``left``.
tolerance : :class:`Column`, optional
an asof tolerance within this range; must be compatible
with the merge index.
allowExactMatches : bool, optional
default ``True``.
direction : str, optional
default ``backward``. Must be one of: ``backward``, ``forward``, and ``nearest``.
Examples
--------
The following performs an as-of join between ``left`` and ``right``.
>>> left = spark.createDataFrame([(1, "a"), (5, "b"), (10, "c")], ["a", "left_val"])
>>> right = spark.createDataFrame([(1, 1), (2, 2), (3, 3), (6, 6), (7, 7)],
... ["a", "right_val"])
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a"
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=1, left_val='a', right_val=1),
Row(a=5, left_val='b', right_val=3),
Row(a=10, left_val='c', right_val=7)]
>>> from pyspark.sql import functions as F
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a", tolerance=F.lit(1)
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=1, left_val='a', right_val=1)]
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a", how="left", tolerance=F.lit(1)
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=1, left_val='a', right_val=1),
Row(a=5, left_val='b', right_val=None),
Row(a=10, left_val='c', right_val=None)]
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a", allowExactMatches=False
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=5, left_val='b', right_val=3),
Row(a=10, left_val='c', right_val=7)]
>>> left._joinAsOf(
... right, leftAsOfColumn="a", rightAsOfColumn="a", direction="forward"
... ).select(left.a, 'left_val', 'right_val').sort("a").collect()
[Row(a=1, left_val='a', right_val=1),
Row(a=5, left_val='b', right_val=6)]
"""
if isinstance(leftAsOfColumn, str):
leftAsOfColumn = self[leftAsOfColumn]
left_as_of_jcol = leftAsOfColumn._jc
if isinstance(rightAsOfColumn, str):
rightAsOfColumn = other[rightAsOfColumn]
right_as_of_jcol = rightAsOfColumn._jc

if on is not None and not isinstance(on, list):
on = [on]

if on is not None:
if isinstance(on[0], str):
on = self._jseq(on)
else:
assert isinstance(on[0], Column), "on should be Column or list of Column"
on = reduce(lambda x, y: x.__and__(y), on)
on = on._jc

if how is None:
how = "inner"
assert isinstance(how, str), "how should be a string"

if tolerance is not None:
assert isinstance(tolerance, Column), "tolerance should be Column"
tolerance = tolerance._jc

jdf = self._jdf.joinAsOf(
other._jdf,
left_as_of_jcol, right_as_of_jcol,
on,
how, tolerance, allowExactMatches, direction
)
return DataFrame(jdf, self.sql_ctx)

def sortWithinPartitions(self, *cols, **kwargs):
"""Returns a new :class:`DataFrame` with each partition sorted by the specified column(s).
@@ -249,6 +249,20 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
s"join condition '${condition.sql}' " +
s"of type ${condition.dataType.catalogString} is not a boolean.")

case j @ AsOfJoin(_, _, _, Some(condition), _, _, _)
if condition.dataType != BooleanType =>
failAnalysis(
s"join condition '${condition.sql}' " +
s"of type ${condition.dataType.catalogString} is not a boolean.")

case j @ AsOfJoin(_, _, _, _, _, _, Some(toleranceAssertion)) =>
if (!toleranceAssertion.foldable) {
failAnalysis("Input argument tolerance must be a constant.")
}
if (!toleranceAssertion.eval().asInstanceOf[Boolean]) {
failAnalysis("Input argument tolerance must be non-negative.")
}

case a @ Aggregate(groupingExprs, aggregateExprs, child) =>
def isAggregateExpression(expr: Expression): Boolean = {
expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr)
@@ -506,6 +520,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
""".stripMargin)

case j: AsOfJoin if !j.duplicateResolved =>
val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
failAnalysis(
s"""
|Failure when resolving conflicting references in AsOfJoin:
|$plan
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
|""".stripMargin)

// TODO: although map type is not orderable, technically map type should be able to be
// used in equality comparison, remove this type check once we support it.
case o if mapColumnInSetOperation(o).isDefined =>
@@ -41,14 +41,18 @@ case class ReferenceEqualPlanWrapper(plan: LogicalPlan) {
object DeduplicateRelations extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
renewDuplicatedRelations(mutable.HashSet.empty, plan)._1.resolveOperatorsUpWithPruning(
_.containsAnyPattern(JOIN, LATERAL_JOIN, INTERSECT, EXCEPT, UNION, COMMAND), ruleId) {
_.containsAnyPattern(JOIN, LATERAL_JOIN, AS_OF_JOIN, INTERSECT, EXCEPT, UNION, COMMAND),
ruleId) {
case p: LogicalPlan if !p.childrenResolved => p
// To resolve duplicate expression IDs for Join.
case j @ Join(left, right, _, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
// Resolve duplicate output for LateralJoin.
case j @ LateralJoin(left, right, _, _) if right.resolved && !j.duplicateResolved =>
j.copy(right = right.withNewPlan(dedupRight(left, right.plan)))
// Resolve duplicate output for AsOfJoin.
case j @ AsOfJoin(left, right, _, _, _, _, _) if !j.duplicateResolved =>
j.copy(right = dedupRight(left, right))
// intersect/except will be rewritten to join at the beginning of optimizer. Here we need to
// deduplicate the right side plan, so that we won't produce an invalid self-join later.
case i @ Intersect(left, right, _) if !i.duplicateResolved =>
@@ -159,7 +159,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
PullOutGroupingExpressions,
ComputeCurrentTime,
ReplaceCurrentLike(catalogManager),
SpecialDatetimeValues) ::
SpecialDatetimeValues,
RewriteAsOfJoin) ::
//////////////////////////////////////////////////////////////////////////////////////////
// Optimizer rules start here
//////////////////////////////////////////////////////////////////////////////////////////
@@ -282,7 +283,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
RewritePredicateSubquery.ruleName ::
NormalizeFloatingNumbers.ruleName ::
ReplaceUpdateFieldsExpression.ruleName ::
PullOutGroupingExpressions.ruleName :: Nil
PullOutGroupingExpressions.ruleName ::
RewriteAsOfJoin.ruleName :: Nil

/**
* Optimize all the subqueries inside expression.