diff --git a/python/pyspark/pandas/namespace.py b/python/pyspark/pandas/namespace.py index 1219b1b74b2ff..8df5d2cce5a53 100644 --- a/python/pyspark/pandas/namespace.py +++ b/python/pyspark/pandas/namespace.py @@ -73,6 +73,7 @@ align_diff_frames, default_session, is_name_like_tuple, + is_name_like_value, name_like_string, same_anchor, scol_for, @@ -83,11 +84,13 @@ InternalFrame, DEFAULT_SERIES_NAME, HIDDEN_COLUMNS, + SPARK_INDEX_NAME_FORMAT, ) from pyspark.pandas.series import Series, first_series from pyspark.pandas.spark import functions as SF from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale from pyspark.pandas.indexes import Index, DatetimeIndex +from pyspark.pandas.indexes.multi import MultiIndex __all__ = [ @@ -115,6 +118,7 @@ "read_sql", "read_json", "merge", + "merge_asof", "to_numeric", "broadcast", "read_orc", @@ -2747,6 +2751,499 @@ def merge( ) +def merge_asof( + left: Union[DataFrame, Series], + right: Union[DataFrame, Series], + on: Optional[Name] = None, + left_on: Optional[Name] = None, + right_on: Optional[Name] = None, + left_index: bool = False, + right_index: bool = False, + by: Optional[Union[Name, List[Name]]] = None, + left_by: Optional[Union[Name, List[Name]]] = None, + right_by: Optional[Union[Name, List[Name]]] = None, + suffixes: Tuple[str, str] = ("_x", "_y"), + tolerance: Optional[Any] = None, + allow_exact_matches: bool = True, + direction: str = "backward", +) -> DataFrame: + """ + Perform an asof merge. + + This is similar to a left-join except that we match on nearest + key rather than equal keys. + + For each row in the left DataFrame: + + - A "backward" search selects the last row in the right DataFrame whose + 'on' key is less than or equal to the left's key. + + - A "forward" search selects the first row in the right DataFrame whose + 'on' key is greater than or equal to the left's key. + + - A "nearest" search selects the row in the right DataFrame whose 'on' + key is closest in absolute distance to the left's key. + + Optionally match on equivalent keys with 'by' before searching with 'on'. + + .. versionadded:: 3.3.0 + + Parameters + ---------- + left : DataFrame or named Series + right : DataFrame or named Series + on : label + Field name to join on. Must be found in both DataFrames. + The data MUST be ordered. Furthermore this must be a numeric column, + such as datetimelike, integer, or float. On or left_on/right_on + must be given. + left_on : label + Field name to join on in left DataFrame. + right_on : label + Field name to join on in right DataFrame. + left_index : bool + Use the index of the left DataFrame as the join key. + right_index : bool + Use the index of the right DataFrame as the join key. + by : column name or list of column names + Match on these columns before performing merge operation. + left_by : column name + Field names to match on in the left DataFrame. + right_by : column name + Field names to match on in the right DataFrame. + suffixes : 2-length sequence (tuple, list, ...) + Suffix to apply to overlapping column names in the left and right + side, respectively. + tolerance : int or Timedelta, optional, default None + Select asof tolerance within this range; must be compatible + with the merge index. + allow_exact_matches : bool, default True + + - If True, allow matching with the same 'on' value + (i.e. less-than-or-equal-to / greater-than-or-equal-to) + - If False, don't match the same 'on' value + (i.e., strictly less-than / strictly greater-than). 
+ + direction : 'backward' (default), 'forward', or 'nearest' + Whether to search for prior, subsequent, or closest matches. + + Returns + ------- + merged : DataFrame + + See Also + -------- + merge : Merge with a database-style join. + merge_ordered : Merge with optional filling/interpolation. + + Examples + -------- + >>> left = ps.DataFrame({"a": [1, 5, 10], "left_val": ["a", "b", "c"]}) + >>> left + a left_val + 0 1 a + 1 5 b + 2 10 c + + >>> right = ps.DataFrame({"a": [1, 2, 3, 6, 7], "right_val": [1, 2, 3, 6, 7]}) + >>> right + a right_val + 0 1 1 + 1 2 2 + 2 3 3 + 3 6 6 + 4 7 7 + + >>> ps.merge_asof(left, right, on="a").sort_values("a").reset_index(drop=True) + a left_val right_val + 0 1 a 1 + 1 5 b 3 + 2 10 c 7 + + >>> ps.merge_asof( + ... left, + ... right, + ... on="a", + ... allow_exact_matches=False + ... ).sort_values("a").reset_index(drop=True) + a left_val right_val + 0 1 a NaN + 1 5 b 3.0 + 2 10 c 7.0 + + >>> ps.merge_asof( + ... left, + ... right, + ... on="a", + ... direction="forward" + ... ).sort_values("a").reset_index(drop=True) + a left_val right_val + 0 1 a 1.0 + 1 5 b 6.0 + 2 10 c NaN + + >>> ps.merge_asof( + ... left, + ... right, + ... on="a", + ... direction="nearest" + ... ).sort_values("a").reset_index(drop=True) + a left_val right_val + 0 1 a 1 + 1 5 b 6 + 2 10 c 7 + + We can use indexed DataFrames as well. + + >>> left = ps.DataFrame({"left_val": ["a", "b", "c"]}, index=[1, 5, 10]) + >>> left + left_val + 1 a + 5 b + 10 c + + >>> right = ps.DataFrame({"right_val": [1, 2, 3, 6, 7]}, index=[1, 2, 3, 6, 7]) + >>> right + right_val + 1 1 + 2 2 + 3 3 + 6 6 + 7 7 + + >>> ps.merge_asof(left, right, left_index=True, right_index=True).sort_index() + left_val right_val + 1 a 1 + 5 b 3 + 10 c 7 + + Here is a real-world times-series example + + >>> quotes = ps.DataFrame( + ... { + ... "time": [ + ... pd.Timestamp("2016-05-25 13:30:00.023"), + ... pd.Timestamp("2016-05-25 13:30:00.023"), + ... pd.Timestamp("2016-05-25 13:30:00.030"), + ... pd.Timestamp("2016-05-25 13:30:00.041"), + ... pd.Timestamp("2016-05-25 13:30:00.048"), + ... pd.Timestamp("2016-05-25 13:30:00.049"), + ... pd.Timestamp("2016-05-25 13:30:00.072"), + ... pd.Timestamp("2016-05-25 13:30:00.075") + ... ], + ... "ticker": [ + ... "GOOG", + ... "MSFT", + ... "MSFT", + ... "MSFT", + ... "GOOG", + ... "AAPL", + ... "GOOG", + ... "MSFT" + ... ], + ... "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01], + ... "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03] + ... } + ... ) + >>> quotes + time ticker bid ask + 0 2016-05-25 13:30:00.023 GOOG 720.50 720.93 + 1 2016-05-25 13:30:00.023 MSFT 51.95 51.96 + 2 2016-05-25 13:30:00.030 MSFT 51.97 51.98 + 3 2016-05-25 13:30:00.041 MSFT 51.99 52.00 + 4 2016-05-25 13:30:00.048 GOOG 720.50 720.93 + 5 2016-05-25 13:30:00.049 AAPL 97.99 98.01 + 6 2016-05-25 13:30:00.072 GOOG 720.50 720.88 + 7 2016-05-25 13:30:00.075 MSFT 52.01 52.03 + + >>> trades = ps.DataFrame( + ... { + ... "time": [ + ... pd.Timestamp("2016-05-25 13:30:00.023"), + ... pd.Timestamp("2016-05-25 13:30:00.038"), + ... pd.Timestamp("2016-05-25 13:30:00.048"), + ... pd.Timestamp("2016-05-25 13:30:00.048"), + ... pd.Timestamp("2016-05-25 13:30:00.048") + ... ], + ... "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], + ... "price": [51.95, 51.95, 720.77, 720.92, 98.0], + ... "quantity": [75, 155, 100, 100, 100] + ... } + ... 
) + >>> trades + time ticker price quantity + 0 2016-05-25 13:30:00.023 MSFT 51.95 75 + 1 2016-05-25 13:30:00.038 MSFT 51.95 155 + 2 2016-05-25 13:30:00.048 GOOG 720.77 100 + 3 2016-05-25 13:30:00.048 GOOG 720.92 100 + 4 2016-05-25 13:30:00.048 AAPL 98.00 100 + + By default we are taking the asof of the quotes + + >>> ps.merge_asof( + ... trades, quotes, on="time", by="ticker" + ... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True) + time ticker price quantity bid ask + 0 2016-05-25 13:30:00.023 MSFT 51.95 75 51.95 51.96 + 1 2016-05-25 13:30:00.038 MSFT 51.95 155 51.97 51.98 + 2 2016-05-25 13:30:00.048 AAPL 98.00 100 NaN NaN + 3 2016-05-25 13:30:00.048 GOOG 720.77 100 720.50 720.93 + 4 2016-05-25 13:30:00.048 GOOG 720.92 100 720.50 720.93 + + We only asof within 2ms between the quote time and the trade time + + >>> ps.merge_asof( + ... trades, + ... quotes, + ... on="time", + ... by="ticker", + ... tolerance=F.expr("INTERVAL 2 MILLISECONDS") # pd.Timedelta("2ms") + ... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True) + time ticker price quantity bid ask + 0 2016-05-25 13:30:00.023 MSFT 51.95 75 51.95 51.96 + 1 2016-05-25 13:30:00.038 MSFT 51.95 155 NaN NaN + 2 2016-05-25 13:30:00.048 AAPL 98.00 100 NaN NaN + 3 2016-05-25 13:30:00.048 GOOG 720.77 100 720.50 720.93 + 4 2016-05-25 13:30:00.048 GOOG 720.92 100 720.50 720.93 + + We only asof within 10ms between the quote time and the trade time + and we exclude exact matches on time. However *prior* data will + propagate forward + + >>> ps.merge_asof( + ... trades, + ... quotes, + ... on="time", + ... by="ticker", + ... tolerance=F.expr("INTERVAL 10 MILLISECONDS"), # pd.Timedelta("10ms") + ... allow_exact_matches=False + ... ).sort_values(["time", "ticker", "price"]).reset_index(drop=True) + time ticker price quantity bid ask + 0 2016-05-25 13:30:00.023 MSFT 51.95 75 NaN NaN + 1 2016-05-25 13:30:00.038 MSFT 51.95 155 51.97 51.98 + 2 2016-05-25 13:30:00.048 AAPL 98.00 100 NaN NaN + 3 2016-05-25 13:30:00.048 GOOG 720.77 100 NaN NaN + 4 2016-05-25 13:30:00.048 GOOG 720.92 100 NaN NaN + """ + + def to_list(os: Optional[Union[Name, List[Name]]]) -> List[Label]: + if os is None: + return [] + elif is_name_like_tuple(os): + return [os] # type: ignore + elif is_name_like_value(os): + return [(os,)] + else: + return [o if is_name_like_tuple(o) else (o,) for o in os] + + if isinstance(left, Series): + left = left.to_frame() + if isinstance(right, Series): + right = right.to_frame() + + if on: + if left_on or right_on: + raise ValueError( + 'Can only pass argument "on" OR "left_on" and "right_on", ' + "not a combination of both." 
+ ) + left_as_of_names = list(map(left._internal.spark_column_name_for, to_list(on))) + right_as_of_names = list(map(right._internal.spark_column_name_for, to_list(on))) + else: + if left_index: + if isinstance(left.index, MultiIndex): + raise ValueError("left can only have one index") + left_as_of_names = left._internal.index_spark_column_names + else: + left_as_of_names = list(map(left._internal.spark_column_name_for, to_list(left_on))) + if right_index: + if isinstance(right.index, MultiIndex): + raise ValueError("right can only have one index") + right_as_of_names = right._internal.index_spark_column_names + else: + right_as_of_names = list(map(right._internal.spark_column_name_for, to_list(right_on))) + + if left_as_of_names and not right_as_of_names: + raise ValueError("Must pass right_on or right_index=True") + if right_as_of_names and not left_as_of_names: + raise ValueError("Must pass left_on or left_index=True") + if not left_as_of_names and not right_as_of_names: + common = list(left.columns.intersection(right.columns)) + if len(common) == 0: + raise ValueError( + "No common columns to perform merge on. Merge options: " + "left_on=None, right_on=None, left_index=False, right_index=False" + ) + left_as_of_names = list(map(left._internal.spark_column_name_for, to_list(common))) + right_as_of_names = list(map(right._internal.spark_column_name_for, to_list(common))) + + if len(left_as_of_names) != 1: + raise ValueError("can only asof on a key for left") + if len(right_as_of_names) != 1: + raise ValueError("can only asof on a key for right") + + if by: + if left_by or right_by: + raise ValueError('Can only pass argument "on" OR "left_by" and "right_by".') + left_join_on_names = list(map(left._internal.spark_column_name_for, to_list(by))) + right_join_on_names = list(map(right._internal.spark_column_name_for, to_list(by))) + else: + left_join_on_names = list(map(left._internal.spark_column_name_for, to_list(left_by))) + right_join_on_names = list(map(right._internal.spark_column_name_for, to_list(right_by))) + + if left_join_on_names and not right_join_on_names: + raise ValueError("missing right_by") + if right_join_on_names and not left_join_on_names: + raise ValueError("missing left_by") + if len(left_join_on_names) != len(right_join_on_names): + raise ValueError("left_by and right_by must be same length") + + # We should distinguish the name to avoid ambiguous column name after merging. 
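+    # resolve() below re-aliases every right-side column with a "__right_" prefix, so
+    # apply the same prefix to the right as-of/join key names here; the prefix is
+    # stripped again when the output columns are assembled.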
+ right_prefix = "__right_" + right_as_of_names = [right_prefix + right_as_of_name for right_as_of_name in right_as_of_names] + right_join_on_names = [ + right_prefix + right_join_on_name for right_join_on_name in right_join_on_names + ] + + left_as_of_name = left_as_of_names[0] + right_as_of_name = right_as_of_names[0] + + def resolve(internal: InternalFrame, side: str) -> InternalFrame: + rename = lambda col: "__{}_{}".format(side, col) + internal = internal.resolved_copy + sdf = internal.spark_frame + sdf = sdf.select( + *[ + scol_for(sdf, col).alias(rename(col)) + for col in sdf.columns + if col not in HIDDEN_COLUMNS + ], + *HIDDEN_COLUMNS + ) + return internal.copy( + spark_frame=sdf, + index_spark_columns=[ + scol_for(sdf, rename(col)) for col in internal.index_spark_column_names + ], + index_fields=[field.copy(name=rename(field.name)) for field in internal.index_fields], + data_spark_columns=[ + scol_for(sdf, rename(col)) for col in internal.data_spark_column_names + ], + data_fields=[field.copy(name=rename(field.name)) for field in internal.data_fields], + ) + + left_internal = left._internal.resolved_copy + right_internal = resolve(right._internal, "right") + + left_table = left_internal.spark_frame.alias("left_table") + right_table = right_internal.spark_frame.alias("right_table") + + left_as_of_column = scol_for(left_table, left_as_of_name) + right_as_of_column = scol_for(right_table, right_as_of_name) + + if left_join_on_names: + left_join_on_columns = [scol_for(left_table, label) for label in left_join_on_names] + right_join_on_columns = [scol_for(right_table, label) for label in right_join_on_names] + on = reduce( + lambda l, r: l & r, + [l == r for l, r in zip(left_join_on_columns, right_join_on_columns)], + ) + else: + on = None + + if tolerance is not None and not isinstance(tolerance, Column): + tolerance = SF.lit(tolerance) + + as_of_joined_table = left_table._joinAsOf( + right_table, + leftAsOfColumn=left_as_of_column, + rightAsOfColumn=right_as_of_column, + on=on, + how="left", + tolerance=tolerance, + allowExactMatches=allow_exact_matches, + direction=direction, + ) + + # Unpack suffixes tuple for convenience + left_suffix = suffixes[0] + right_suffix = suffixes[1] + + # Append suffixes to columns with the same name to avoid conflicts later + duplicate_columns = set(left_internal.column_labels) & set(right_internal.column_labels) + + exprs = [] + data_columns = [] + column_labels = [] + + left_scol_for = lambda label: scol_for( + as_of_joined_table, left_internal.spark_column_name_for(label) + ) + right_scol_for = lambda label: scol_for( + as_of_joined_table, right_internal.spark_column_name_for(label) + ) + + for label in left_internal.column_labels: + col = left_internal.spark_column_name_for(label) + scol = left_scol_for(label) + if label in duplicate_columns: + spark_column_name = left_internal.spark_column_name_for(label) + if spark_column_name in (left_as_of_names + left_join_on_names) and ( + (right_prefix + spark_column_name) in (right_as_of_names + right_join_on_names) + ): + pass + else: + col = col + left_suffix + scol = scol.alias(col) + label = tuple([str(label[0]) + left_suffix] + list(label[1:])) + exprs.append(scol) + data_columns.append(col) + column_labels.append(label) + for label in right_internal.column_labels: + # recover `right_prefix` here. 
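+        # i.e. drop the leading "__right_" so the output keeps the original column name.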
+ col = right_internal.spark_column_name_for(label)[len(right_prefix) :] + scol = right_scol_for(label).alias(col) + if label in duplicate_columns: + spark_column_name = left_internal.spark_column_name_for(label) + if spark_column_name in left_as_of_names + left_join_on_names and ( + (right_prefix + spark_column_name) in right_as_of_names + right_join_on_names + ): + continue + else: + col = col + right_suffix + scol = scol.alias(col) + label = tuple([str(label[0]) + right_suffix] + list(label[1:])) + exprs.append(scol) + data_columns.append(col) + column_labels.append(label) + + # Retain indices if they are used for joining + if left_index or right_index: + index_spark_column_names = [ + SPARK_INDEX_NAME_FORMAT(i) for i in range(len(left_internal.index_spark_column_names)) + ] + left_index_scols = [ + scol.alias(name) + for scol, name in zip(left_internal.index_spark_columns, index_spark_column_names) + ] + exprs.extend(left_index_scols) + index_names = left_internal.index_names + else: + index_spark_column_names = [] + index_names = [] + + selected_columns = as_of_joined_table.select(*exprs) + + internal = InternalFrame( + spark_frame=selected_columns, + index_spark_columns=[scol_for(selected_columns, col) for col in index_spark_column_names], + index_names=index_names, + column_labels=column_labels, + data_spark_columns=[scol_for(selected_columns, col) for col in data_columns], + ) + return DataFrame(internal) + + @no_type_check def to_numeric(arg, errors="raise"): """ diff --git a/python/pyspark/pandas/tests/test_reshape.py b/python/pyspark/pandas/tests/test_reshape.py index 162ab78fd64ed..f2a0cb18e07de 100644 --- a/python/pyspark/pandas/tests/test_reshape.py +++ b/python/pyspark/pandas/tests/test_reshape.py @@ -24,6 +24,7 @@ from pyspark import pandas as ps from pyspark.pandas.utils import name_like_string +from pyspark.sql.utils import AnalysisException from pyspark.testing.pandasutils import PandasOnSparkTestCase @@ -283,6 +284,143 @@ def test_get_dummies_multiindex_columns(self): pd.get_dummies(pdf, columns=("x", 1), dtype=np.int8).rename(columns=name_like_string), ) + def test_merge_asof(self): + pdf_left = pd.DataFrame( + {"a": [1, 5, 10], "b": ["x", "y", "z"], "left_val": ["a", "b", "c"]}, index=[10, 20, 30] + ) + pdf_right = pd.DataFrame( + {"a": [1, 2, 3, 6, 7], "b": ["v", "w", "x", "y", "z"], "right_val": [1, 2, 3, 6, 7]}, + index=[100, 101, 102, 103, 104], + ) + psdf_left = ps.from_pandas(pdf_left) + psdf_right = ps.from_pandas(pdf_right) + + self.assert_eq( + pd.merge_asof(pdf_left, pdf_right, on="a").sort_values("a").reset_index(drop=True), + ps.merge_asof(psdf_left, psdf_right, on="a").sort_values("a").reset_index(drop=True), + ) + self.assert_eq( + ( + pd.merge_asof(pdf_left, pdf_right, left_on="a", right_on="a") + .sort_values("a") + .reset_index(drop=True) + ), + ( + ps.merge_asof(psdf_left, psdf_right, left_on="a", right_on="a") + .sort_values("a") + .reset_index(drop=True) + ), + ) + if LooseVersion(pd.__version__) >= LooseVersion("1.3"): + self.assert_eq( + pd.merge_asof( + pdf_left.set_index("a"), pdf_right, left_index=True, right_on="a" + ).sort_index(), + ps.merge_asof( + psdf_left.set_index("a"), psdf_right, left_index=True, right_on="a" + ).sort_index(), + ) + else: + expected = pd.DataFrame( + { + "b_x": ["x", "y", "z"], + "left_val": ["a", "b", "c"], + "a": [1, 3, 7], + "b_y": ["v", "x", "z"], + "right_val": [1, 3, 7], + }, + index=pd.Index([1, 5, 10], name="a"), + ) + self.assert_eq( + expected, + ps.merge_asof( + psdf_left.set_index("a"), psdf_right, 
left_index=True, right_on="a" + ).sort_index(), + ) + self.assert_eq( + pd.merge_asof( + pdf_left, pdf_right.set_index("a"), left_on="a", right_index=True + ).sort_index(), + ps.merge_asof( + psdf_left, psdf_right.set_index("a"), left_on="a", right_index=True + ).sort_index(), + ) + self.assert_eq( + pd.merge_asof( + pdf_left.set_index("a"), pdf_right.set_index("a"), left_index=True, right_index=True + ).sort_index(), + ps.merge_asof( + psdf_left.set_index("a"), + psdf_right.set_index("a"), + left_index=True, + right_index=True, + ).sort_index(), + ) + self.assert_eq( + ( + pd.merge_asof(pdf_left, pdf_right, on="a", by="b") + .sort_values("a") + .reset_index(drop=True) + ), + ( + ps.merge_asof(psdf_left, psdf_right, on="a", by="b") + .sort_values("a") + .reset_index(drop=True) + ), + ) + self.assert_eq( + ( + pd.merge_asof(pdf_left, pdf_right, on="a", tolerance=1) + .sort_values("a") + .reset_index(drop=True) + ), + ( + ps.merge_asof(psdf_left, psdf_right, on="a", tolerance=1) + .sort_values("a") + .reset_index(drop=True) + ), + ) + self.assert_eq( + ( + pd.merge_asof(pdf_left, pdf_right, on="a", allow_exact_matches=False) + .sort_values("a") + .reset_index(drop=True) + ), + ( + ps.merge_asof(psdf_left, psdf_right, on="a", allow_exact_matches=False) + .sort_values("a") + .reset_index(drop=True) + ), + ) + self.assert_eq( + ( + pd.merge_asof(pdf_left, pdf_right, on="a", direction="forward") + .sort_values("a") + .reset_index(drop=True) + ), + ( + ps.merge_asof(psdf_left, psdf_right, on="a", direction="forward") + .sort_values("a") + .reset_index(drop=True) + ), + ) + self.assert_eq( + ( + pd.merge_asof(pdf_left, pdf_right, on="a", direction="nearest") + .sort_values("a") + .reset_index(drop=True) + ), + ( + ps.merge_asof(psdf_left, psdf_right, on="a", direction="nearest") + .sort_values("a") + .reset_index(drop=True) + ), + ) + + self.assertRaises( + AnalysisException, lambda: ps.merge_asof(psdf_left, psdf_right, on="a", tolerance=-1) + ) + if __name__ == "__main__": import unittest diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5a2e8cfaa1481..de289e1c1f67e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1357,6 +1357,123 @@ def join(self, other, on=None, how=None): jdf = self._jdf.join(other._jdf, on, how) return DataFrame(jdf, self.sql_ctx) + # TODO(SPARK-22947): Fix the DataFrame API. + def _joinAsOf( + self, + other, + leftAsOfColumn, + rightAsOfColumn, + on=None, + how=None, + *, + tolerance=None, + allowExactMatches=True, + direction="backward", + ): + """ + Perform an as-of join. + + This is similar to a left-join except that we match on nearest + key rather than equal keys. + + .. versionadded:: 3.3.0 + + Parameters + ---------- + other : :class:`DataFrame` + Right side of the join + leftAsOfColumn : str or :class:`Column` + a string for the as-of join column name, or a Column + rightAsOfColumn : str or :class:`Column` + a string for the as-of join column name, or a Column + on : str, list or :class:`Column`, optional + a string for the join column name, a list of column names, + a join expression (Column), or a list of Columns. + If `on` is a string or a list of strings indicating the name of the join column(s), + the column(s) must exist on both sides, and this performs an equi-join. + how : str, optional + default ``inner``. Must be one of: ``inner`` and ``left``. + tolerance : :class:`Column`, optional + an asof tolerance within this range; must be compatible + with the merge index. 
+ allowExactMatches : bool, optional + default ``True``. + direction : str, optional + default ``backward``. Must be one of: ``backward``, ``forward``, and ``nearest``. + + Examples + -------- + The following performs an as-of join between ``left`` and ``right``. + + >>> left = spark.createDataFrame([(1, "a"), (5, "b"), (10, "c")], ["a", "left_val"]) + >>> right = spark.createDataFrame([(1, 1), (2, 2), (3, 3), (6, 6), (7, 7)], + ... ["a", "right_val"]) + >>> left._joinAsOf( + ... right, leftAsOfColumn="a", rightAsOfColumn="a" + ... ).select(left.a, 'left_val', 'right_val').sort("a").collect() + [Row(a=1, left_val='a', right_val=1), + Row(a=5, left_val='b', right_val=3), + Row(a=10, left_val='c', right_val=7)] + + >>> from pyspark.sql import functions as F + >>> left._joinAsOf( + ... right, leftAsOfColumn="a", rightAsOfColumn="a", tolerance=F.lit(1) + ... ).select(left.a, 'left_val', 'right_val').sort("a").collect() + [Row(a=1, left_val='a', right_val=1)] + + >>> left._joinAsOf( + ... right, leftAsOfColumn="a", rightAsOfColumn="a", how="left", tolerance=F.lit(1) + ... ).select(left.a, 'left_val', 'right_val').sort("a").collect() + [Row(a=1, left_val='a', right_val=1), + Row(a=5, left_val='b', right_val=None), + Row(a=10, left_val='c', right_val=None)] + + >>> left._joinAsOf( + ... right, leftAsOfColumn="a", rightAsOfColumn="a", allowExactMatches=False + ... ).select(left.a, 'left_val', 'right_val').sort("a").collect() + [Row(a=5, left_val='b', right_val=3), + Row(a=10, left_val='c', right_val=7)] + + >>> left._joinAsOf( + ... right, leftAsOfColumn="a", rightAsOfColumn="a", direction="forward" + ... ).select(left.a, 'left_val', 'right_val').sort("a").collect() + [Row(a=1, left_val='a', right_val=1), + Row(a=5, left_val='b', right_val=6)] + """ + if isinstance(leftAsOfColumn, str): + leftAsOfColumn = self[leftAsOfColumn] + left_as_of_jcol = leftAsOfColumn._jc + if isinstance(rightAsOfColumn, str): + rightAsOfColumn = other[rightAsOfColumn] + right_as_of_jcol = rightAsOfColumn._jc + + if on is not None and not isinstance(on, list): + on = [on] + + if on is not None: + if isinstance(on[0], str): + on = self._jseq(on) + else: + assert isinstance(on[0], Column), "on should be Column or list of Column" + on = reduce(lambda x, y: x.__and__(y), on) + on = on._jc + + if how is None: + how = "inner" + assert isinstance(how, str), "how should be a string" + + if tolerance is not None: + assert isinstance(tolerance, Column), "tolerance should be Column" + tolerance = tolerance._jc + + jdf = self._jdf.joinAsOf( + other._jdf, + left_as_of_jcol, right_as_of_jcol, + on, + how, tolerance, allowExactMatches, direction + ) + return DataFrame(jdf, self.sql_ctx) + def sortWithinPartitions(self, *cols, **kwargs): """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s). 
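The Catalyst changes that follow do not execute `AsOfJoin` directly; the `RewriteAsOfJoin` rule turns it into an aggregate over the filtered right side. As a rough, hand-written illustration of the same idea (not the actual rule, which uses `MIN_BY` inside a correlated scalar subquery), a "backward" as-of join over the example data from the `_joinAsOf` docstring can be sketched in plain PySpark as:

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.getOrCreate()

    left = spark.createDataFrame([(1, "a"), (5, "b"), (10, "c")], ["a", "left_val"])
    right = spark.createDataFrame(
        [(1, 1), (2, 2), (3, 3), (6, 6), (7, 7)], ["a", "right_val"]
    ).withColumnRenamed("a", "right_a")

    # For every left row, keep only right rows with right_a <= a ("backward" search),
    # then pick the closest one, i.e. the struct minimizing the distance a - right_a.
    result = (
        left.join(right, F.col("a") >= F.col("right_a"), "left")
        .groupBy("a", "left_val")
        .agg(F.expr("min_by(struct(right_a, right_val), a - right_a)").alias("nearest"))
        .select("a", "left_val", F.col("nearest.right_val").alias("right_val"))
        .orderBy("a")
    )
    result.show()
    # +---+--------+---------+
    # |  a|left_val|right_val|
    # +---+--------+---------+
    # |  1|       a|        1|
    # |  5|       b|        3|
    # | 10|       c|        7|
    # +---+--------+---------+

The `min_by(struct(...), distance)` aggregate is the same ordering trick the optimizer rule relies on; the `direction` and `allowExactMatches` parameters only change the filter predicate and the distance expression, as implemented in `AsOfJoin.makeAsOfCond` and `makeOrderingExpr` below.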
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index b62e93416b371..c8614b1ea3b8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -249,6 +249,20 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { s"join condition '${condition.sql}' " + s"of type ${condition.dataType.catalogString} is not a boolean.") + case j @ AsOfJoin(_, _, _, Some(condition), _, _, _) + if condition.dataType != BooleanType => + failAnalysis( + s"join condition '${condition.sql}' " + + s"of type ${condition.dataType.catalogString} is not a boolean.") + + case j @ AsOfJoin(_, _, _, _, _, _, Some(toleranceAssertion)) => + if (!toleranceAssertion.foldable) { + failAnalysis("Input argument tolerance must be a constant.") + } + if (!toleranceAssertion.eval().asInstanceOf[Boolean]) { + failAnalysis("Input argument tolerance must be non-negative.") + } + case a @ Aggregate(groupingExprs, aggregateExprs, child) => def isAggregateExpression(expr: Expression): Boolean = { expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupedAggPandasUDF(expr) @@ -506,6 +520,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { |Conflicting attributes: ${conflictingAttributes.mkString(",")} """.stripMargin) + case j: AsOfJoin if !j.duplicateResolved => + val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in AsOfJoin: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) + // TODO: although map type is not orderable, technically map type should be able to be // used in equality comparison, remove this type check once we support it. case o if mapColumnInSetOperation(o).isDefined => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index 7b37891de2edf..5dfed394f31e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -41,7 +41,8 @@ case class ReferenceEqualPlanWrapper(plan: LogicalPlan) { object DeduplicateRelations extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { renewDuplicatedRelations(mutable.HashSet.empty, plan)._1.resolveOperatorsUpWithPruning( - _.containsAnyPattern(JOIN, LATERAL_JOIN, INTERSECT, EXCEPT, UNION, COMMAND), ruleId) { + _.containsAnyPattern(JOIN, LATERAL_JOIN, AS_OF_JOIN, INTERSECT, EXCEPT, UNION, COMMAND), + ruleId) { case p: LogicalPlan if !p.childrenResolved => p // To resolve duplicate expression IDs for Join. case j @ Join(left, right, _, _, _) if !j.duplicateResolved => @@ -49,6 +50,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] { // Resolve duplicate output for LateralJoin. case j @ LateralJoin(left, right, _, _) if right.resolved && !j.duplicateResolved => j.copy(right = right.withNewPlan(dedupRight(left, right.plan))) + // Resolve duplicate output for AsOfJoin. 
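+      // dedupRight rewrites the conflicting right-side attributes with fresh
+      // expression IDs, mirroring the handling of Join above.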
+ case j @ AsOfJoin(left, right, _, _, _, _, _) if !j.duplicateResolved => + j.copy(right = dedupRight(left, right)) // intersect/except will be rewritten to join at the beginning of optimizer. Here we need to // deduplicate the right side plan, so that we won't produce an invalid self-join later. case i @ Intersect(left, right, _) if !i.duplicateResolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ed16185ae211e..b8c7fe752195c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -159,7 +159,8 @@ abstract class Optimizer(catalogManager: CatalogManager) PullOutGroupingExpressions, ComputeCurrentTime, ReplaceCurrentLike(catalogManager), - SpecialDatetimeValues) :: + SpecialDatetimeValues, + RewriteAsOfJoin) :: ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here ////////////////////////////////////////////////////////////////////////////////////////// @@ -282,7 +283,8 @@ abstract class Optimizer(catalogManager: CatalogManager) RewritePredicateSubquery.ruleName :: NormalizeFloatingNumbers.ruleName :: ReplaceUpdateFieldsExpression.ruleName :: - PullOutGroupingExpressions.ruleName :: Nil + PullOutGroupingExpressions.ruleName :: + RewriteAsOfJoin.ruleName :: Nil /** * Optimize all the subqueries inside expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteAsOfJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteAsOfJoin.scala new file mode 100644 index 0000000000000..bd93b50031cb9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteAsOfJoin.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.trees.TreePattern._ + +/** + * Replaces logical [[AsOfJoin]] operator using a combination of Join and Aggregate operator. 
+ * + * Input Pseudo-Query: + * {{{ + * SELECT * FROM left ASOF JOIN right ON (condition, as_of on(left.t, right.t), tolerance) + * }}} + * + * Rewritten Query: + * {{{ + * SELECT left.*, __right__.* + * FROM ( + * SELECT + * left.*, + * ( + * SELECT MIN_BY(STRUCT(right.*), left.t - right.t) AS __nearest_right__ + * FROM right + * WHERE condition AND left.t >= right.t AND right.t >= left.t - tolerance + * ) as __right__ + * FROM left + * ) + * WHERE __right__ IS NOT NULL + * }}} + */ +object RewriteAsOfJoin extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( + _.containsPattern(AS_OF_JOIN), ruleId) { + case AsOfJoin(left, right, asOfCondition, condition, joinType, orderExpression, _) => + val conditionWithOuterReference = + condition.map(And(_, asOfCondition)).getOrElse(asOfCondition).transformUp { + case a: AttributeReference if left.outputSet.contains(a) => + OuterReference(a) + } + val filtered = Filter(conditionWithOuterReference, right) + + val orderExpressionWithOuterReference = orderExpression.transformUp { + case a: AttributeReference if left.outputSet.contains(a) => + OuterReference(a) + } + val rightStruct = CreateStruct(right.output) + val nearestRight = MinBy(rightStruct, orderExpressionWithOuterReference) + .toAggregateExpression() + val aggExpr = Alias(nearestRight, "__nearest_right__")() + val aggregate = Aggregate(Seq.empty, Seq(aggExpr), filtered) + + val projectWithScalarSubquery = Project( + left.output :+ Alias(ScalarSubquery(aggregate, left.output), "__right__")(), + left) + + val filterRight = joinType match { + case LeftOuter => projectWithScalarSubquery + case _ => + Filter(IsNotNull(projectWithScalarSubquery.output.last), projectWithScalarSubquery) + } + + Project( + left.output ++ right.output.zipWithIndex.map { + case (out, idx) => + Alias(GetStructField(filterRight.output.last, idx), out.name)(exprId = out.exprId) + }, + filterRight) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index da3cfb4c9de07..eeec3cd765ad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -121,3 +121,24 @@ object LeftSemiOrAnti { case _ => None } } + +object AsOfJoinDirection { + + def apply(direction: String): AsOfJoinDirection = { + direction.toLowerCase(Locale.ROOT) match { + case "forward" => Forward + case "backward" => Backward + case "nearest" => Nearest + case _ => + val supported = Seq("forward", "backward", "nearest") + throw new IllegalArgumentException(s"Unsupported as-of join direction '$direction'. 
" + + "Supported as-of join direction include: " + supported.mkString("'", "', '", "'") + ".") + } + } +} + +sealed abstract class AsOfJoinDirection + +case object Forward extends AsOfJoinDirection +case object Backward extends AsOfJoinDirection +case object Nearest extends AsOfJoinDirection diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 269d18a276e1f..7b4c2bc1c61be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -1597,3 +1597,119 @@ case class LateralJoin( copy(left = newChild) } } + +/** + * A logical plan for as-of join. + */ +case class AsOfJoin( + left: LogicalPlan, + right: LogicalPlan, + asOfCondition: Expression, + condition: Option[Expression], + joinType: JoinType, + orderExpression: Expression, + toleranceAssertion: Option[Expression]) extends BinaryNode { + + require(Seq(Inner, LeftOuter).contains(joinType), + s"Unsupported as-of join type $joinType") + + override protected def stringArgs: Iterator[Any] = super.stringArgs.take(5) + + override def output: Seq[Attribute] = { + joinType match { + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case _ => + left.output ++ right.output + } + } + + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + + override lazy val resolved: Boolean = { + childrenResolved && + expressions.forall(_.resolved) && + duplicateResolved && + asOfCondition.dataType == BooleanType && + condition.forall(_.dataType == BooleanType) && + toleranceAssertion.forall { assertion => + assertion.foldable && assertion.eval().asInstanceOf[Boolean] + } + } + + final override val nodePatterns: Seq[TreePattern] = Seq(AS_OF_JOIN) + + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): AsOfJoin = { + copy(left = newLeft, right = newRight) + } +} + +object AsOfJoin { + + def apply( + left: LogicalPlan, + right: LogicalPlan, + leftAsOf: Expression, + rightAsOf: Expression, + condition: Option[Expression], + joinType: JoinType, + tolerance: Option[Expression], + allowExactMatches: Boolean, + direction: AsOfJoinDirection): AsOfJoin = { + val asOfCond = makeAsOfCond(leftAsOf, rightAsOf, tolerance, allowExactMatches, direction) + val orderingExpr = makeOrderingExpr(leftAsOf, rightAsOf, direction) + AsOfJoin(left, right, asOfCond, condition, joinType, + orderingExpr, tolerance.map(t => GreaterThanOrEqual(t, Literal.default(t.dataType)))) + } + + private def makeAsOfCond( + leftAsOf: Expression, + rightAsOf: Expression, + tolerance: Option[Expression], + allowExactMatches: Boolean, + direction: AsOfJoinDirection): Expression = { + val base = (allowExactMatches, direction) match { + case (true, Backward) => GreaterThanOrEqual(leftAsOf, rightAsOf) + case (false, Backward) => GreaterThan(leftAsOf, rightAsOf) + case (true, Forward) => LessThanOrEqual(leftAsOf, rightAsOf) + case (false, Forward) => LessThan(leftAsOf, rightAsOf) + case (true, Nearest) => Literal.TrueLiteral + case (false, Nearest) => Not(EqualTo(leftAsOf, rightAsOf)) + } + tolerance match { + case Some(tolerance) => + (allowExactMatches, direction) match { + case (true, Backward) => + And(base, GreaterThanOrEqual(rightAsOf, Subtract(leftAsOf, tolerance))) + case (false, 
Backward) => + And(base, GreaterThan(rightAsOf, Subtract(leftAsOf, tolerance))) + case (true, Forward) => + And(base, LessThanOrEqual(rightAsOf, Add(leftAsOf, tolerance))) + case (false, Forward) => + And(base, LessThan(rightAsOf, Add(leftAsOf, tolerance))) + case (true, Nearest) => + And(GreaterThanOrEqual(rightAsOf, Subtract(leftAsOf, tolerance)), + LessThanOrEqual(rightAsOf, Add(leftAsOf, tolerance))) + case (false, Nearest) => + And(base, + And(GreaterThan(rightAsOf, Subtract(leftAsOf, tolerance)), + LessThan(rightAsOf, Add(leftAsOf, tolerance)))) + } + case None => base + } + } + + private def makeOrderingExpr( + leftAsOf: Expression, + rightAsOf: Expression, + direction: AsOfJoinDirection): Expression = { + direction match { + case Backward => Subtract(leftAsOf, rightAsOf) + case Forward => Subtract(rightAsOf, leftAsOf) + case Nearest => + If(GreaterThan(leftAsOf, rightAsOf), + Subtract(leftAsOf, rightAsOf), Subtract(rightAsOf, leftAsOf)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 2a05b8533bac1..d207ebc468973 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -144,6 +144,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.optimizer.ReplaceIntersectWithSemiJoin" :: "org.apache.spark.sql.catalyst.optimizer.RewriteExceptAll" :: "org.apache.spark.sql.catalyst.optimizer.RewriteIntersectAll" :: + "org.apache.spark.sql.catalyst.optimizer.RewriteAsOfJoin" :: "org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison" :: "org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions" :: "org.apache.spark.sql.catalyst.optimizer.SimplifyCasts" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index bb57e5a898be2..6c1b64dd0af6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -91,6 +91,7 @@ object TreePattern extends Enumeration { // Logical plan patterns (alphabetically ordered) val AGGREGATE: Value = Value + val AS_OF_JOIN: Value = Value val COMMAND: Value = Value val CTE: Value = Value val DISTINCT_LIKE: Value = Value diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteAsOfJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteAsOfJoinSuite.scala new file mode 100644 index 0000000000000..41f8e25943d8d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteAsOfJoinSuite.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{CreateStruct, GetStructField, If, OuterReference, ScalarSubquery} +import org.apache.spark.sql.catalyst.expressions.aggregate.MinBy +import org.apache.spark.sql.catalyst.plans.{AsOfJoinDirection, Inner, LeftOuter, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{AsOfJoin, LocalRelation} + +class RewriteAsOfJoinSuite extends PlanTest { + + test("simple") { + val left = LocalRelation('a.int, 'b.int, 'c.int) + val right = LocalRelation('a.int, 'b.int, 'd.int) + val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner, + tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("backward")) + + val rewritten = RewriteAsOfJoin(query.analyze) + + val filter = OuterReference(left.output(0)) >= right.output(0) + val rightStruct = CreateStruct(right.output) + val orderExpression = OuterReference(left.output(0)) - right.output(0) + val nearestRight = MinBy(rightStruct, orderExpression) + .toAggregateExpression().as("__nearest_right__") + + val scalarSubquery = left.select( + left.output :+ ScalarSubquery( + right.where(filter).groupBy()(nearestRight), + left.output).as("__right__"): _*) + val correctAnswer = scalarSubquery + .where(scalarSubquery.output.last.isNotNull) + .select(left.output :+ + GetStructField(scalarSubquery.output.last, 0).as("a") :+ + GetStructField(scalarSubquery.output.last, 1).as("b") :+ + GetStructField(scalarSubquery.output.last, 2).as("d"): _*) + + comparePlans(rewritten, correctAnswer, checkAnalysis = false) + } + + test("condition") { + val left = LocalRelation('a.int, 'b.int, 'c.int) + val right = LocalRelation('a.int, 'b.int, 'd.int) + val query = AsOfJoin(left, right, left.output(0), right.output(0), + Some(left.output(1) === right.output(1)), Inner, + tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("backward")) + + val rewritten = RewriteAsOfJoin(query.analyze) + + val filter = OuterReference(left.output(1)) === right.output(1) && + OuterReference(left.output(0)) >= right.output(0) + val rightStruct = CreateStruct(right.output) + val orderExpression = OuterReference(left.output(0)) - right.output(0) + val nearestRight = MinBy(rightStruct, orderExpression) + .toAggregateExpression().as("__nearest_right__") + + val scalarSubquery = left.select( + left.output :+ ScalarSubquery( + right.where(filter).groupBy()(nearestRight), + left.output).as("__right__"): _*) + val correctAnswer = scalarSubquery + .where(scalarSubquery.output.last.isNotNull) + .select(left.output :+ + GetStructField(scalarSubquery.output.last, 0).as("a") :+ + GetStructField(scalarSubquery.output.last, 1).as("b") :+ + GetStructField(scalarSubquery.output.last, 2).as("d"): _*) + + comparePlans(rewritten, correctAnswer, checkAnalysis = false) + } + + test("left outer") { + val left = LocalRelation('a.int, 'b.int, 'c.int) + val right = LocalRelation('a.int, 'b.int, 'd.int) + val query = AsOfJoin(left, right, 
left.output(0), right.output(0), None, Inner, + tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("backward")) + + val rewritten = RewriteAsOfJoin(query.analyze) + + val filter = OuterReference(left.output(0)) >= right.output(0) + val rightStruct = CreateStruct(right.output) + val orderExpression = OuterReference(left.output(0)) - right.output(0) + val nearestRight = MinBy(rightStruct, orderExpression) + .toAggregateExpression().as("__nearest_right__") + + val scalarSubquery = left.select( + left.output :+ ScalarSubquery( + right.where(filter).groupBy()(nearestRight), + left.output).as("__right__"): _*) + val correctAnswer = scalarSubquery + .where(scalarSubquery.output.last.isNotNull) + .select(left.output :+ + GetStructField(scalarSubquery.output.last, 0).as("a") :+ + GetStructField(scalarSubquery.output.last, 1).as("b") :+ + GetStructField(scalarSubquery.output.last, 2).as("d"): _*) + + comparePlans(rewritten, correctAnswer, checkAnalysis = false) + } + + test("tolerance") { + val left = LocalRelation('a.int, 'b.int, 'c.int) + val right = LocalRelation('a.int, 'b.int, 'd.int) + val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner, + tolerance = Some(1), allowExactMatches = true, direction = AsOfJoinDirection("backward")) + + val rewritten = RewriteAsOfJoin(query.analyze) + + val filter = OuterReference(left.output(0)) >= right.output(0) && + right.output(0) >= OuterReference(left.output(0)) - 1 + val rightStruct = CreateStruct(right.output) + val orderExpression = OuterReference(left.output(0)) - right.output(0) + val nearestRight = MinBy(rightStruct, orderExpression) + .toAggregateExpression().as("__nearest_right__") + + val scalarSubquery = left.select( + left.output :+ ScalarSubquery( + right.where(filter).groupBy()(nearestRight), + left.output).as("__right__"): _*) + val correctAnswer = scalarSubquery + .where(scalarSubquery.output.last.isNotNull) + .select(left.output :+ + GetStructField(scalarSubquery.output.last, 0).as("a") :+ + GetStructField(scalarSubquery.output.last, 1).as("b") :+ + GetStructField(scalarSubquery.output.last, 2).as("d"): _*) + + comparePlans(rewritten, correctAnswer, checkAnalysis = false) + } + + test("allowExactMatches = false") { + val left = LocalRelation('a.int, 'b.int, 'c.int) + val right = LocalRelation('a.int, 'b.int, 'd.int) + val query = AsOfJoin(left, right, left.output(0), right.output(0), None, LeftOuter, + tolerance = None, allowExactMatches = false, direction = AsOfJoinDirection("backward")) + + val rewritten = RewriteAsOfJoin(query.analyze) + + val filter = OuterReference(left.output(0)) > right.output(0) + val rightStruct = CreateStruct(right.output) + val orderExpression = OuterReference(left.output(0)) - right.output(0) + val nearestRight = MinBy(rightStruct, orderExpression) + .toAggregateExpression().as("__nearest_right__") + + val scalarSubquery = left.select( + left.output :+ ScalarSubquery( + right.where(filter).groupBy()(nearestRight), + left.output).as("__right__"): _*) + val correctAnswer = scalarSubquery + .select(left.output :+ + GetStructField(scalarSubquery.output.last, 0).as("a") :+ + GetStructField(scalarSubquery.output.last, 1).as("b") :+ + GetStructField(scalarSubquery.output.last, 2).as("d"): _*) + + comparePlans(rewritten, correctAnswer, checkAnalysis = false) + } + + test("tolerance & allowExactMatches = false") { + val left = LocalRelation('a.int, 'b.int, 'c.int) + val right = LocalRelation('a.int, 'b.int, 'd.int) + val query = AsOfJoin(left, right, left.output(0), 
right.output(0), None, Inner, + tolerance = Some(1), allowExactMatches = false, direction = AsOfJoinDirection("backward")) + + val rewritten = RewriteAsOfJoin(query.analyze) + + val filter = OuterReference(left.output(0)) > right.output(0) && + right.output(0) > OuterReference(left.output(0)) - 1 + val rightStruct = CreateStruct(right.output) + val orderExpression = OuterReference(left.output(0)) - right.output(0) + val nearestRight = MinBy(rightStruct, orderExpression) + .toAggregateExpression().as("__nearest_right__") + + val scalarSubquery = left.select( + left.output :+ ScalarSubquery( + right.where(filter).groupBy()(nearestRight), + left.output).as("__right__"): _*) + val correctAnswer = scalarSubquery + .where(scalarSubquery.output.last.isNotNull) + .select(left.output :+ + GetStructField(scalarSubquery.output.last, 0).as("a") :+ + GetStructField(scalarSubquery.output.last, 1).as("b") :+ + GetStructField(scalarSubquery.output.last, 2).as("d"): _*) + + comparePlans(rewritten, correctAnswer, checkAnalysis = false) + } + + test("direction = forward") { + val left = LocalRelation('a.int, 'b.int, 'c.int) + val right = LocalRelation('a.int, 'b.int, 'd.int) + val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner, + tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("forward")) + + val rewritten = RewriteAsOfJoin(query.analyze) + + val filter = OuterReference(left.output(0)) <= right.output(0) + val rightStruct = CreateStruct(right.output) + val orderExpression = right.output(0) - OuterReference(left.output(0)) + val nearestRight = MinBy(rightStruct, orderExpression) + .toAggregateExpression().as("__nearest_right__") + + val scalarSubquery = left.select( + left.output :+ ScalarSubquery( + right.where(filter).groupBy()(nearestRight), + left.output).as("__right__"): _*) + val correctAnswer = scalarSubquery + .where(scalarSubquery.output.last.isNotNull) + .select(left.output :+ + GetStructField(scalarSubquery.output.last, 0).as("a") :+ + GetStructField(scalarSubquery.output.last, 1).as("b") :+ + GetStructField(scalarSubquery.output.last, 2).as("d"): _*) + + comparePlans(rewritten, correctAnswer, checkAnalysis = false) + } + + test("direction = nearest") { + val left = LocalRelation('a.int, 'b.int, 'c.int) + val right = LocalRelation('a.int, 'b.int, 'd.int) + val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner, + tolerance = None, allowExactMatches = true, direction = AsOfJoinDirection("nearest")) + + val rewritten = RewriteAsOfJoin(query.analyze) + + val filter = true + val rightStruct = CreateStruct(right.output) + val orderExpression = If(OuterReference(left.output(0)) > right.output(0), + OuterReference(left.output(0)) - right.output(0), + right.output(0) - OuterReference(left.output(0))) + val nearestRight = MinBy(rightStruct, orderExpression) + .toAggregateExpression().as("__nearest_right__") + + val scalarSubquery = left.select( + left.output :+ ScalarSubquery( + right.where(filter).groupBy()(nearestRight), + left.output).as("__right__"): _*) + val correctAnswer = scalarSubquery + .where(scalarSubquery.output.last.isNotNull) + .select(left.output :+ + GetStructField(scalarSubquery.output.last, 0).as("a") :+ + GetStructField(scalarSubquery.output.last, 1).as("b") :+ + GetStructField(scalarSubquery.output.last, 2).as("d"): _*) + + comparePlans(rewritten, correctAnswer, checkAnalysis = false) + } + + test("tolerance & allowExactMatches = false & direction = nearest") { + val left = LocalRelation('a.int, 'b.int, 
'c.int) + val right = LocalRelation('a.int, 'b.int, 'd.int) + val query = AsOfJoin(left, right, left.output(0), right.output(0), None, Inner, + tolerance = Some(1), allowExactMatches = false, direction = AsOfJoinDirection("nearest")) + + val rewritten = RewriteAsOfJoin(query.analyze) + + val filter = (!(OuterReference(left.output(0)) === right.output(0))) && + ((right.output(0) > OuterReference(left.output(0)) - 1) && + (right.output(0) < OuterReference(left.output(0)) + 1)) + val rightStruct = CreateStruct(right.output) + val orderExpression = If(OuterReference(left.output(0)) > right.output(0), + OuterReference(left.output(0)) - right.output(0), + right.output(0) - OuterReference(left.output(0))) + val nearestRight = MinBy(rightStruct, orderExpression) + .toAggregateExpression().as("__nearest_right__") + + val scalarSubquery = left.select( + left.output :+ ScalarSubquery( + right.where(filter).groupBy()(nearestRight), + left.output).as("__right__"): _*) + val correctAnswer = scalarSubquery + .where(scalarSubquery.output.last.isNotNull) + .select(left.output :+ + GetStructField(scalarSubquery.output.last, 0).as("a") :+ + GetStructField(scalarSubquery.output.last, 1).as("b") :+ + GetStructField(scalarSubquery.output.last, 2).as("d"): _*) + + comparePlans(rewritten, correctAnswer, checkAnalysis = false) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1e85551990923..22e914ec9c45e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1047,30 +1047,12 @@ class Dataset[T] private[sql]( } /** - * Join with another `DataFrame`, using the given join expression. The following performs - * a full outer join between `df1` and `df2`. - * - * {{{ - * // Scala: - * import org.apache.spark.sql.functions._ - * df1.join(df2, $"df1Key" === $"df2Key", "outer") - * - * // Java: - * import static org.apache.spark.sql.functions.*; - * df1.join(df2, col("df1Key").equalTo(col("df2Key")), "outer"); - * }}} - * - * @param right Right side of the join. - * @param joinExprs Join expression. - * @param joinType Type of join to perform. Default `inner`. Must be one of: - * `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`, - * `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`, - * `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, left_anti`. - * - * @group untypedrel - * @since 2.0.0 + * find the trivially true predicates and automatically resolves them to both sides. */ - def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = { + private def resolveSelfJoinCondition( + right: Dataset[_], + joinExprs: Option[Column], + joinType: String): Join = { // Note that in this function, we introduce a hack in the case of self-join to automatically // resolve ambiguous join conditions into ones that might make sense [SPARK-6231]. // Consider this case: df.join(df, df("key") === df("key")) @@ -1082,27 +1064,56 @@ class Dataset[T] private[sql]( // Trigger analysis so in the case of self-join, the analyzer will clone the plan. // After the cloning, left and right side will have distinct expression ids. 
val plan = withPlan( - Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr), JoinHint.NONE)) + Join(logicalPlan, right.logicalPlan, + JoinType(joinType), joinExprs.map(_.expr), JoinHint.NONE)) .queryExecution.analyzed.asInstanceOf[Join] // If auto self join alias is disabled, return the plan. if (!sparkSession.sessionState.conf.dataFrameSelfJoinAutoResolveAmbiguity) { - return withPlan(plan) + return plan } // If left/right have no output set intersection, return the plan. val lanalyzed = this.queryExecution.analyzed val ranalyzed = right.queryExecution.analyzed if (lanalyzed.outputSet.intersect(ranalyzed.outputSet).isEmpty) { - return withPlan(plan) + return plan } // Otherwise, find the trivially true predicates and automatically resolves them to both sides. // By the time we get here, since we have already run analysis, all attributes should've been // resolved and become AttributeReference. + resolveSelfJoinCondition(plan) + } + + /** + * Join with another `DataFrame`, using the given join expression. The following performs + * a full outer join between `df1` and `df2`. + * + * {{{ + * // Scala: + * import org.apache.spark.sql.functions._ + * df1.join(df2, $"df1Key" === $"df2Key", "outer") + * + * // Java: + * import static org.apache.spark.sql.functions.*; + * df1.join(df2, col("df1Key").equalTo(col("df2Key")), "outer"); + * }}} + * + * @param right Right side of the join. + * @param joinExprs Join expression. + * @param joinType Type of join to perform. Default `inner`. Must be one of: + * `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`, + * `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`, + * `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, left_anti`. + * + * @group untypedrel + * @since 2.0.0 + */ + def join(right: Dataset[_], joinExprs: Column, joinType: String): DataFrame = { withPlan { - resolveSelfJoinCondition(plan) + resolveSelfJoinCondition(right, Some(joinExprs), joinType) } } @@ -1232,6 +1243,58 @@ class Dataset[T] private[sql]( joinWith(other, condition, "inner") } + // TODO(SPARK-22947): Fix the DataFrame API. + private[sql] def joinAsOf( + other: Dataset[_], + leftAsOf: Column, + rightAsOf: Column, + usingColumns: Seq[String], + joinType: String, + tolerance: Column, + allowExactMatches: Boolean, + direction: String): DataFrame = { + val joinExprs = usingColumns.map { column => + EqualTo(resolve(column), other.resolve(column)) + }.reduceOption(And).map(Column.apply).orNull + + joinAsOf(other, leftAsOf, rightAsOf, joinExprs, joinType, + tolerance, allowExactMatches, direction) + } + + // TODO(SPARK-22947): Fix the DataFrame API. 
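+  // Variant taking an arbitrary join expression instead of a list of column names.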
+  private[sql] def joinAsOf(
+      other: Dataset[_],
+      leftAsOf: Column,
+      rightAsOf: Column,
+      joinExprs: Column,
+      joinType: String,
+      tolerance: Column,
+      allowExactMatches: Boolean,
+      direction: String): DataFrame = {
+    val joined = resolveSelfJoinCondition(other, Option(joinExprs), joinType)
+    val leftAsOfExpr = leftAsOf.expr.transformUp {
+      case a: AttributeReference if logicalPlan.outputSet.contains(a) =>
+        val index = logicalPlan.output.indexWhere(_.exprId == a.exprId)
+        joined.left.output(index)
+    }
+    val rightAsOfExpr = rightAsOf.expr.transformUp {
+      case a: AttributeReference if other.logicalPlan.outputSet.contains(a) =>
+        val index = other.logicalPlan.output.indexWhere(_.exprId == a.exprId)
+        joined.right.output(index)
+    }
+    withPlan {
+      AsOfJoin(
+        joined.left, joined.right,
+        leftAsOfExpr, rightAsOfExpr,
+        joined.condition,
+        joined.joinType,
+        Option(tolerance).map(_.expr),
+        allowExactMatches,
+        AsOfJoinDirection(direction)
+      )
+    }
+  }
+
   /**
    * Returns a new Dataset with each partition sorted by the given expressions.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala
new file mode 100644
index 0000000000000..749efe95c5d2b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsOfJoinSuite.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+class DataFrameAsOfJoinSuite extends QueryTest
+  with SharedSparkSession
+  with AdaptiveSparkPlanHelper {
+
+  def prepareForAsOfJoin(): (DataFrame, DataFrame) = {
+    val schema1 = StructType(
+      StructField("a", IntegerType, false) ::
+        StructField("b", StringType, false) ::
+        StructField("left_val", StringType, false) :: Nil)
+    val rowSeq1: List[Row] = List(Row(1, "x", "a"), Row(5, "y", "b"), Row(10, "z", "c"))
+    val df1 = spark.createDataFrame(rowSeq1.asJava, schema1)
+
+    val schema2 = StructType(
+      StructField("a", IntegerType) ::
+        StructField("b", StringType) ::
+        StructField("right_val", IntegerType) :: Nil)
+    val rowSeq2: List[Row] = List(Row(1, "v", 1), Row(2, "w", 2), Row(3, "x", 3),
+      Row(6, "y", 6), Row(7, "z", 7))
+    val df2 = spark.createDataFrame(rowSeq2.asJava, schema2)
+
+    (df1, df2)
+  }
+
+  test("as-of join - simple") {
+    val (df1, df2) = prepareForAsOfJoin()
+    checkAnswer(
+      df1.joinAsOf(
+        df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+        joinType = "inner", tolerance = null, allowExactMatches = true, direction = "backward"),
+      Seq(
+        Row(1, "x", "a", 1, "v", 1),
+        Row(5, "y", "b", 3, "x", 3),
+        Row(10, "z", "c", 7, "z", 7)
+      )
+    )
+  }
+
+  test("as-of join - usingColumns") {
+    val (df1, df2) = prepareForAsOfJoin()
+    checkAnswer(
+      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"),
+        joinType = "inner", tolerance = null, allowExactMatches = true, direction = "backward"),
+      Seq(
+        Row(10, "z", "c", 7, "z", 7)
+      )
+    )
+  }
+
+  test("as-of join - usingColumns, left outer") {
+    val (df1, df2) = prepareForAsOfJoin()
+    checkAnswer(
+      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq("b"),
+        joinType = "left", tolerance = null, allowExactMatches = true, direction = "backward"),
+      Seq(
+        Row(1, "x", "a", null, null, null),
+        Row(5, "y", "b", null, null, null),
+        Row(10, "z", "c", 7, "z", 7)
+      )
+    )
+  }
+
+  test("as-of join - tolerance = 1") {
+    val (df1, df2) = prepareForAsOfJoin()
+    checkAnswer(
+      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+        joinType = "inner", tolerance = lit(1), allowExactMatches = true, direction = "backward"),
+      Seq(
+        Row(1, "x", "a", 1, "v", 1)
+      )
+    )
+  }
+
+  test("as-of join - tolerance should be a constant") {
+    val (df1, df2) = prepareForAsOfJoin()
+    val errMsg = intercept[AnalysisException] {
+      df1.joinAsOf(
+        df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+        joinType = "inner", tolerance = df1.col("b"), allowExactMatches = true,
+        direction = "backward")
+    }.getMessage
+    assert(errMsg.contains("Input argument tolerance must be a constant."))
+  }
+
+  test("as-of join - tolerance should be non-negative") {
+    val (df1, df2) = prepareForAsOfJoin()
+    val errMsg = intercept[AnalysisException] {
+      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+        joinType = "inner", tolerance = lit(-1), allowExactMatches = true, direction = "backward")
+    }.getMessage
+    assert(errMsg.contains("Input argument tolerance must be non-negative."))
+  }
+
+  test("as-of join - allowExactMatches = false") {
+    val (df1, df2) = prepareForAsOfJoin()
+    checkAnswer(
+      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+        joinType = "inner", tolerance = null, allowExactMatches = false, direction = "backward"),
+      Seq(
+        Row(5, "y", "b", 3, "x", 3),
+        Row(10, "z", "c", 7, "z", 7)
+      )
+    )
+  }
+
+  test("as-of join - direction = \"forward\"") {
+    val (df1, df2) = prepareForAsOfJoin()
+    checkAnswer(
+      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+        joinType = "inner", tolerance = null, allowExactMatches = true, direction = "forward"),
+      Seq(
+        Row(1, "x", "a", 1, "v", 1),
+        Row(5, "y", "b", 6, "y", 6)
+      )
+    )
+  }
+
+  test("as-of join - direction = \"nearest\"") {
+    val (df1, df2) = prepareForAsOfJoin()
+    checkAnswer(
+      df1.joinAsOf(df2, df1.col("a"), df2.col("a"), usingColumns = Seq.empty,
+        joinType = "inner", tolerance = null, allowExactMatches = true, direction = "nearest"),
+      Seq(
+        Row(1, "x", "a", 1, "v", 1),
+        Row(5, "y", "b", 6, "y", 6),
+        Row(10, "z", "c", 7, "z", 7)
+      )
+    )
+  }
+
+  test("as-of join - self") {
+    val (df1, _) = prepareForAsOfJoin()
+    checkAnswer(
+      df1.joinAsOf(
+        df1, df1.col("a"), df1.col("a"), usingColumns = Seq.empty,
+        joinType = "left", tolerance = null, allowExactMatches = false, direction = "nearest"),
+      Seq(
+        Row(1, "x", "a", 5, "y", "b"),
+        Row(5, "y", "b", 1, "x", "a"),
+        Row(10, "z", "c", 5, "y", "b")
+      )
+    )
+  }
+}
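
Editor's note (not part of the patch): the "backward" semantics exercised by the joinAsOf tests above can be sketched with public DataFrame APIs only. The snippet below is a minimal, self-contained Scala illustration on the same toy keys as prepareForAsOfJoin(); the local SparkSession setup and the max_by-based aggregation are illustrative assumptions, not what the RewriteAsOfJoin rule actually generates (the rule builds a correlated scalar subquery, as the planner test earlier in this diff shows).

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object AsOfJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("asof-sketch").getOrCreate()
    import spark.implicits._

    // Same key columns as the test data, without the extra "b" column.
    val left = Seq((1, "a"), (5, "b"), (10, "c")).toDF("a", "left_val")
    val right = Seq((1, 1), (2, 2), (3, 3), (6, 6), (7, 7)).toDF("a", "right_val")

    // "backward" as-of: for each left row, keep the right row with the largest key <= the left key.
    val backward = left.as("l")
      .join(right.as("r"), col("l.a") >= col("r.a"), "left")
      .groupBy(col("l.a"), col("l.left_val"))
      .agg(expr("max_by(r.right_val, r.a)").as("right_val"))

    backward.orderBy("a").show()
    // expected rows: (1, a, 1), (5, b, 3), (10, c, 7)

    spark.stop()
  }
}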