-
Notifications
You must be signed in to change notification settings - Fork 29.1k
[SPARK-38937][PYTHON] interpolate support param limit_direction
#36246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3253,16 +3253,17 @@ def ffill( | |
|
|
||
| pad = ffill | ||
|
|
||
| # TODO: add 'axis', 'inplace', 'limit_direction', 'limit_area', 'downcast' | ||
| # TODO: add 'axis', 'inplace', 'limit_area', 'downcast' | ||
| def interpolate( | ||
| self: FrameLike, | ||
| method: Optional[str] = None, | ||
| method: str = "linear", | ||
| limit: Optional[int] = None, | ||
| limit_direction: Optional[str] = None, | ||
| ) -> FrameLike: | ||
| """ | ||
| Fill NaN values using an interpolation method. | ||
|
|
||
| .. note:: the current implementation of rank uses Spark's Window without | ||
| .. note:: the current implementation of interpolate uses Spark's Window without | ||
| specifying partition specification. This leads to move all data into | ||
| single partition in single machine and could cause serious | ||
| performance degradation. Avoid this method against very large dataset. | ||
|
|
@@ -3281,6 +3282,10 @@ def interpolate( | |
| Maximum number of consecutive NaNs to fill. Must be greater than | ||
| 0. | ||
|
|
||
| limit_direction : str, default None | ||
| Consecutive NaNs will be filled in this direction. | ||
| One of {{'forward', 'backward', 'both'}}. | ||
|
||
|
|
||
| Returns | ||
| ------- | ||
| Series or DataFrame or None | ||
|
|
@@ -3335,7 +3340,7 @@ def interpolate( | |
| 2 2.0 3.0 -3.0 9.0 | ||
| 3 2.0 4.0 -4.0 16.0 | ||
| """ | ||
| return self.interpolate(method=method, limit=limit) | ||
| return self.interpolate(method=method, limit=limit, limit_direction=limit_direction) | ||
|
|
||
| @property | ||
| def at(self) -> AtIndexer: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2169,18 +2169,28 @@ def _fillna( | |
| ) | ||
| )._psser_for(self._column_label) | ||
|
|
||
| def interpolate(self, method: Optional[str] = None, limit: Optional[int] = None) -> "Series": | ||
| return self._interpolate(method=method, limit=limit) | ||
| def interpolate( | ||
| self, | ||
| method: str = "linear", | ||
| limit: Optional[int] = None, | ||
| limit_direction: Optional[str] = None, | ||
| ) -> "Series": | ||
| return self._interpolate(method=method, limit=limit, limit_direction=limit_direction) | ||
|
|
||
| def _interpolate( | ||
| self, | ||
| method: Optional[str] = None, | ||
| method: str = "linear", | ||
| limit: Optional[int] = None, | ||
| limit_direction: Optional[str] = None, | ||
| ) -> "Series": | ||
| if (method is not None) and (method not in ["linear"]): | ||
| if method not in ["linear"]: | ||
| raise NotImplementedError("interpolate currently works only for method='linear'") | ||
| if (limit is not None) and (not limit > 0): | ||
| raise ValueError("limit must be > 0.") | ||
| if (limit_direction is not None) and ( | ||
| limit_direction not in ["forward", "backward", "both"] | ||
| ): | ||
| raise ValueError("invalid limit_direction: '{}'".format(limit_direction)) | ||
|
|
||
| if not self.spark.nullable and not isinstance( | ||
| self.spark.data_type, (FloatType, DoubleType) | ||
|
|
@@ -2209,15 +2219,50 @@ def _interpolate( | |
| ) * null_index_forward + last_non_null_forward | ||
|
|
||
| fill_cond = ~F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward) | ||
| pad_cond = F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward) | ||
| if limit is not None: | ||
| fill_cond = fill_cond & (null_index_forward <= F.lit(limit)) | ||
| pad_cond = pad_cond & (null_index_forward <= F.lit(limit)) | ||
|
|
||
| pad_head = SF.lit(None) | ||
|
||
| pad_head_cond = SF.lit(False) | ||
| pad_tail = SF.lit(None) | ||
| pad_tail_cond = SF.lit(False) | ||
|
|
||
| # inputs -> NaN, NaN, 1.0, NaN, NaN, NaN, 5.0, NaN, NaN | ||
| if limit_direction is None or limit_direction == "forward": | ||
| # outputs -> NaN, NaN, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0 | ||
| pad_tail = last_non_null_forward | ||
| pad_tail_cond = F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward) | ||
| if limit is not None: | ||
| # outputs (limit=1) -> NaN, NaN, 1.0, 2.0, NaN, NaN, 5.0, 5.0, NaN | ||
| fill_cond = fill_cond & (null_index_forward <= F.lit(limit)) | ||
| pad_tail_cond = pad_tail_cond & (null_index_forward <= F.lit(limit)) | ||
|
|
||
| elif limit_direction == "backward": | ||
| # outputs -> 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, NaN, NaN | ||
| pad_head = last_non_null_backward | ||
| pad_head_cond = ~F.isnull(last_non_null_backward) & F.isnull(last_non_null_forward) | ||
| if limit is not None: | ||
| # outputs (limit=1) -> NaN, 1.0, 1.0, NaN, NaN, 4.0, 5.0, NaN, NaN | ||
| fill_cond = fill_cond & (null_index_backward <= F.lit(limit)) | ||
| pad_head_cond = pad_head_cond & (null_index_backward <= F.lit(limit)) | ||
|
|
||
| else: | ||
| # outputs -> 1.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 5.0 | ||
| pad_head = last_non_null_backward | ||
| pad_head_cond = ~F.isnull(last_non_null_backward) & F.isnull(last_non_null_forward) | ||
| pad_tail = last_non_null_forward | ||
| pad_tail_cond = F.isnull(last_non_null_backward) & ~F.isnull(last_non_null_forward) | ||
| if limit is not None: | ||
| # outputs (limit=1) -> NaN, 1.0, 1.0, 2.0, NaN, 4.0, 5.0, 5.0, NaN | ||
| fill_cond = fill_cond & ( | ||
| (null_index_forward <= F.lit(limit)) | (null_index_backward <= F.lit(limit)) | ||
| ) | ||
| pad_head_cond = pad_head_cond & (null_index_backward <= F.lit(limit)) | ||
| pad_tail_cond = pad_tail_cond & (null_index_forward <= F.lit(limit)) | ||
|
|
||
| cond = self.isnull().spark.column | ||
| scol = ( | ||
| F.when(cond & fill_cond, fill) | ||
| .when(cond & pad_cond, last_non_null_forward) | ||
| .when(cond & pad_head_cond, pad_head) | ||
| .when(cond & pad_tail_cond, pad_tail) | ||
| .otherwise(scol) | ||
| ) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,17 +33,18 @@ def test_interpolate_error(self): | |
| with self.assertRaisesRegex(ValueError, "limit must be > 0"): | ||
| psdf.interpolate(limit=0) | ||
|
|
||
| def _test_series_interpolate(self, pser): | ||
| psser = ps.from_pandas(pser) | ||
| self.assert_eq(psser.interpolate(), pser.interpolate()) | ||
| for l1 in range(1, 5): | ||
| self.assert_eq(psser.interpolate(limit=l1), pser.interpolate(limit=l1)) | ||
|
|
||
| def _test_dataframe_interpolate(self, pdf): | ||
| psdf = ps.from_pandas(pdf) | ||
| self.assert_eq(psdf.interpolate(), pdf.interpolate()) | ||
| for l2 in range(1, 5): | ||
| self.assert_eq(psdf.interpolate(limit=l2), pdf.interpolate(limit=l2)) | ||
| with self.assertRaisesRegex(ValueError, "invalid limit_direction"): | ||
| psdf.interpolate(limit_direction="jump") | ||
|
|
||
| def _test_interpolate(self, pobj): | ||
|
||
| psobj = ps.from_pandas(pobj) | ||
| self.assert_eq(psobj.interpolate(), pobj.interpolate()) | ||
| for limit in range(1, 5): | ||
| for limit_direction in [None, "forward", "backward", "both"]: | ||
| self.assert_eq( | ||
| psobj.interpolate(limit=limit, limit_direction=limit_direction), | ||
| pobj.interpolate(limit=limit, limit_direction=limit_direction), | ||
| ) | ||
|
|
||
| def test_interpolate(self): | ||
| pser = pd.Series( | ||
|
|
@@ -54,7 +55,7 @@ def test_interpolate(self): | |
| ], | ||
| name="a", | ||
| ) | ||
| self._test_series_interpolate(pser) | ||
| self._test_interpolate(pser) | ||
|
|
||
| pser = pd.Series( | ||
| [ | ||
|
|
@@ -64,7 +65,7 @@ def test_interpolate(self): | |
| ], | ||
| name="a", | ||
| ) | ||
| self._test_series_interpolate(pser) | ||
| self._test_interpolate(pser) | ||
|
|
||
| pser = pd.Series( | ||
| [ | ||
|
|
@@ -84,7 +85,7 @@ def test_interpolate(self): | |
| ], | ||
| name="a", | ||
| ) | ||
| self._test_series_interpolate(pser) | ||
| self._test_interpolate(pser) | ||
|
|
||
| pdf = pd.DataFrame( | ||
| [ | ||
|
|
@@ -96,7 +97,7 @@ def test_interpolate(self): | |
| ], | ||
| columns=list("abc"), | ||
| ) | ||
| self._test_dataframe_interpolate(pdf) | ||
| self._test_interpolate(pdf) | ||
|
|
||
| pdf = pd.DataFrame( | ||
| [ | ||
|
|
@@ -108,7 +109,7 @@ def test_interpolate(self): | |
| ], | ||
| columns=list("abcde"), | ||
| ) | ||
| self._test_dataframe_interpolate(pdf) | ||
| self._test_interpolate(pdf) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about adding `versionadded`?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `versionadded` for `interpolate` is already here (https://github.com/apache/spark/blob/master/python/pyspark/pandas/generic.py#L3270, https://github.com/apache/spark/pull/36246/files#diff-cb01df978681ad2406bdd22c29f73aa3020cfcd5c1f6e4052b4bc33f59160d33R3271), so I think there is no need to add it for this param.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My bad, I didn't notice that.