Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1809,6 +1809,11 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
val children = fun.getArgumentsList.asScala.map(transformExpression)
Some(NullIndex(children(0)))

case "timestampdiff" if fun.getArgumentsCount == 3 =>
val children = fun.getArgumentsList.asScala.map(transformExpression)
val unit = extractString(children(0), "unit")
Some(TimestampDiff(unit, children(1), children(2)))

// ML-specific functions
case "vector_to_array" if fun.getArgumentsCount == 2 =>
val expr = transformExpression(fun.getArguments(0))
Expand Down
46 changes: 34 additions & 12 deletions python/pyspark/pandas/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
Generic,
List,
Optional,
Union,
)

import numpy as np
Expand Down Expand Up @@ -65,6 +66,8 @@
scol_for,
verify_temp_column_name,
)
from pyspark.sql.utils import is_remote
from pyspark.pandas.spark.functions import timestampdiff


class Resampler(Generic[FrameLike], metaclass=ABCMeta):
Expand Down Expand Up @@ -131,8 +134,27 @@ def _resamplekey_scol(self) -> Column:
def _agg_columns_scols(self) -> List[Column]:
return [s.spark.column for s in self._agg_columns]

def get_make_interval(self, unit: str, col: Union[Column, int, float]) -> Column:
    """Return an interval Column representing ``col`` units of ``unit``.

    Dispatches to the Spark Connect ``make_interval`` function when running
    remotely; otherwise delegates to the JVM ``PythonSQLUtils.makeInterval``
    helper on the active SparkContext.

    Parameters
    ----------
    unit : str
        Interval unit name: one of "MONTH", "HOUR", "MINUTE", "SECOND".
    col : Column, int or float
        The amount of units; plain numbers are wrapped in a literal Column.

    Returns
    -------
    Column
        An interval expression suitable for timestamp arithmetic.

    Raises
    ------
    ValueError
        If ``unit`` is not a supported unit name on the remote path
        (previously an unknown unit silently fell through and returned None,
        which the old ``# type: ignore[return]`` papered over).
    """
    if is_remote():
        from pyspark.sql.connect.functions import lit, make_interval

        # make_interval takes the amount through a unit-specific keyword arg.
        kwarg_by_unit = {
            "MONTH": "months",
            "HOUR": "hours",
            "MINUTE": "mins",
            "SECOND": "secs",
        }
        if unit not in kwarg_by_unit:
            raise ValueError("Unsupported interval unit: %s" % unit)
        amount = col if not isinstance(col, (int, float)) else lit(col)
        return make_interval(**{kwarg_by_unit[unit]: amount})  # type: ignore
    else:
        # Classic Spark: build the interval on the JVM side.
        # NOTE(review): direct JVM access flagged for removal in a follow-up
        # (see the reviewer comment in the original PR) — TODO confirm.
        sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
        jcol = col._jc if isinstance(col, Column) else F.lit(col)._jc
        return sql_utils.makeInterval(unit, jcol)

def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
origin_scol = F.lit(origin)
(rule_code, n) = (self._offset.rule_code, getattr(self._offset, "n"))
left_closed, right_closed = (self._closed == "left", self._closed == "right")
Expand Down Expand Up @@ -191,18 +213,18 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
truncated_ts_scol = F.date_trunc("MONTH", ts_scol)
edge_label = truncated_ts_scol
if left_closed and right_labeled:
edge_label += sql_utils.makeInterval("MONTH", F.lit(n)._jc)
edge_label += self.get_make_interval("MONTH", n)
elif right_closed and left_labeled:
edge_label -= sql_utils.makeInterval("MONTH", F.lit(n)._jc)
edge_label -= self.get_make_interval("MONTH", n)

if left_labeled:
non_edge_label = F.when(
mod == 0,
truncated_ts_scol - sql_utils.makeInterval("MONTH", F.lit(n)._jc),
).otherwise(truncated_ts_scol - sql_utils.makeInterval("MONTH", mod._jc))
truncated_ts_scol - self.get_make_interval("MONTH", n),
).otherwise(truncated_ts_scol - self.get_make_interval("MONTH", mod))
else:
non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
truncated_ts_scol - sql_utils.makeInterval("MONTH", (mod - n)._jc)
truncated_ts_scol - self.get_make_interval("MONTH", mod - n)
)

return F.to_timestamp(
Expand Down Expand Up @@ -257,7 +279,7 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:
unit_str = unit_mapping[rule_code]

truncated_ts_scol = F.date_trunc(unit_str, ts_scol)
diff = sql_utils.timestampDiff(unit_str, origin_scol._jc, truncated_ts_scol._jc)
diff = timestampdiff(unit_str, origin_scol, truncated_ts_scol)
mod = F.lit(0) if n == 1 else (diff % F.lit(n))

if rule_code == "H":
Expand All @@ -271,19 +293,19 @@ def _bin_time_stamp(self, origin: pd.Timestamp, ts_scol: Column) -> Column:

edge_label = truncated_ts_scol
if left_closed and right_labeled:
edge_label += sql_utils.makeInterval(unit_str, F.lit(n)._jc)
edge_label += self.get_make_interval(unit_str, n)
elif right_closed and left_labeled:
edge_label -= sql_utils.makeInterval(unit_str, F.lit(n)._jc)
edge_label -= self.get_make_interval(unit_str, n)

if left_labeled:
non_edge_label = F.when(mod == 0, truncated_ts_scol).otherwise(
truncated_ts_scol - sql_utils.makeInterval(unit_str, mod._jc)
truncated_ts_scol - self.get_make_interval(unit_str, mod)
)
else:
non_edge_label = F.when(
mod == 0,
truncated_ts_scol + sql_utils.makeInterval(unit_str, F.lit(n)._jc),
).otherwise(truncated_ts_scol - sql_utils.makeInterval(unit_str, (mod - n)._jc))
truncated_ts_scol + self.get_make_interval(unit_str, n),
).otherwise(truncated_ts_scol - self.get_make_interval(unit_str, mod - n))

return F.when(edge_cond, edge_label).otherwise(non_edge_label)

Expand Down
16 changes: 16 additions & 0 deletions python/pyspark/pandas/spark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,19 @@ def null_index(col: Column) -> Column:
else:
sc = SparkContext._active_spark_context
return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))


def timestampdiff(unit: str, start: Column, end: Column) -> Column:
    """Return the difference between two timestamp Columns in the given unit.

    When running under Spark Connect, invokes the ``timestampdiff``
    expression remotely; otherwise calls the JVM
    ``PythonSQLUtils.timestampDiff`` helper on the active SparkContext.
    """
    if not is_remote():
        # Classic Spark: go through the JVM bridge.
        jvm = SparkContext._active_spark_context._jvm
        return Column(jvm.PythonSQLUtils.timestampDiff(unit, start._jc, end._jc))

    from pyspark.sql.connect.functions import _invoke_function_over_columns, lit

    return _invoke_function_over_columns(  # type: ignore[return-value]
        "timestampdiff",
        lit(unit),
        start,  # type: ignore[arg-type]
        end,  # type: ignore[arg-type]
    )
8 changes: 1 addition & 7 deletions python/pyspark/pandas/tests/connect/test_parity_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,7 @@
class ResampleTestsParityMixin(
    ResampleTestsMixin, PandasOnSparkTestUtils, TestUtils, ReusedConnectTestCase
):
    """Parity suite: run the pandas-on-Spark resample tests under Spark Connect.

    All tests are inherited unchanged from ``ResampleTestsMixin``; the former
    per-test skips (SPARK-43660) were removed once ``resample`` became
    supported with Spark Connect, so no overrides remain here.
    """


if __name__ == "__main__":
Expand Down