Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions python/pyspark/pandas/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from pyspark.sql import functions as F, Column
from pyspark.sql.types import DoubleType
from pyspark.pandas.spark import functions as SF
from pyspark.pandas.missing import unsupported_function
from pyspark.pandas.config import get_option
from pyspark.pandas.utils import name_like_string
Expand Down Expand Up @@ -437,6 +438,37 @@ def get_fliers(colname, outliers, min_val):

return fliers

@staticmethod
def get_multicol_fliers(colnames, multicol_outliers, multicol_whiskers):
    """
    Gather the flier (outlier) values of several columns in one select.

    Parameters
    ----------
    colnames : sequence
        Column names to process. It is enumerated twice, so it must be
        re-iterable.
    multicol_outliers : DataFrame
        Frame that is selected from. For every name in `colnames` it is read
        through the value column (referenced as a backtick-quoted name) and
        through a column named ``__<name>_outlier`` used as the ``when``
        condition.
    multicol_whiskers : dict
        Mapping of column name to a mapping that has a ``"min"`` entry.

    Returns
    -------
    dict
        Maps each column name to the ``val`` field extracted from that
        column's ``collect_top_k`` result.
    """
    # Passed as the `num` argument of collect_top_k for every column.
    # presumably caps how many fliers are collected per column -- TODO confirm
    top_k = 1001

    # NOTE(review): ordering the outliers by |value - lower_whisker| mirrors
    # what series.boxplot does, but it feels odd; a distance such as
    # |value - median| or |value - mean| may be more appropriate. To be
    # revisited later.
    aggregated = []
    value_fields = []
    for idx, name in enumerate(colnames):
        value_col = F.col(f"`{name}`")
        is_outlier = F.col(f"__{name}_outlier")
        lower_whisker = multicol_whiskers[name]["min"]

        # Struct of (distance from the lower whisker, original value).
        ord_and_val = F.struct(
            F.abs(value_col - F.lit(lower_whisker)).alias("ord"),
            value_col.alias("val"),
        )
        # Rows that are not flagged as outliers contribute a null pair.
        outlier_pair = (
            F.when(is_outlier, ord_and_val).otherwise(F.lit(None)).alias(f"pair_{idx}")
        )

        aggregated.append(SF.collect_top_k(outlier_pair, top_k, False).alias(f"top_{idx}"))
        value_fields.append(f"top_{idx}.val")

    first_row = multicol_outliers.select(aggregated).select(value_fields).first()

    # The selected fields are positional: field i belongs to colnames[i].
    return {name: first_row[idx] for idx, name in enumerate(colnames)}


class KdePlotBase(NumericPlotBase):
@staticmethod
Expand Down
10 changes: 9 additions & 1 deletion python/pyspark/pandas/plot/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,19 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
# Computes min and max values of non-outliers - the whiskers
whiskers = BoxPlotBase.calc_multicol_whiskers(numeric_column_names, outliers)

fliers = None
if boxpoints:
fliers = BoxPlotBase.get_multicol_fliers(numeric_column_names, outliers, whiskers)

i = 0
for colname in numeric_column_names:
col_stats = multicol_stats[colname]
col_whiskers = whiskers[colname]

col_fliers = None
if fliers is not None and colname in fliers and len(fliers[colname]) > 0:
col_fliers = [fliers[colname]]

fig.add_trace(
go.Box(
x=[i],
Expand All @@ -214,7 +222,7 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs):
mean=[col_stats["mean"]],
lowerfence=[col_whiskers["min"]],
upperfence=[col_whiskers["max"]],
y=None, # todo: support y=fliers
y=col_fliers,
boxpoints=boxpoints,
notched=notched,
**kwargs,
Expand Down
13 changes: 13 additions & 0 deletions python/pyspark/pandas/spark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,19 @@ def null_index(col: Column) -> Column:
return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))


def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
    """
    Build a Column that invokes the internal ``collect_top_k`` function.

    Parameters
    ----------
    col : Column
        Input column handed to ``collect_top_k``.
    num : int
        Forwarded as the second argument of ``collect_top_k``.
    reverse : bool
        Forwarded as the third argument of ``collect_top_k``.

    Returns
    -------
    Column
        The resulting expression. When ``is_remote()`` is true it is built
        through the connect function helper; otherwise through the JVM
        ``PythonSQLUtils.collect_top_k`` of the active SparkContext.
    """
    if not is_remote():
        from pyspark import SparkContext

        jvm = SparkContext._active_spark_context._jvm
        return Column(jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse))

    from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns

    return _invoke_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse))


def make_interval(unit: str, e: Union[Column, int, float]) -> Column:
unit_mapping = {
"YEAR": "years",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ private[sql] object PythonSQLUtils extends Logging {

/** Wraps `e` in a call to the internal function named `null_index`. */
def nullIndex(e: Column): Column = {
  Column.internalFn("null_index", e)
}

/**
 * Builds a call to the internal function named `collect_top_k`.
 *
 * @param e input column passed as the first argument
 * @param num passed to the function as a literal column
 * @param reverse passed to the function as a literal column
 * @return the column representing the `collect_top_k` invocation
 */
def collect_top_k(e: Column, num: Int, reverse: Boolean): Column = {
  val numLit = lit(num)
  val reverseLit = lit(reverse)
  Column.internalFn("collect_top_k", e, numLit, reverseLit)
}

/**
 * Builds a call to the internal function named `pandas_product`.
 *
 * @param e input column passed as the first argument
 * @param ignoreNA passed to the function as a literal column
 * @return the column representing the `pandas_product` invocation
 */
def pandasProduct(e: Column, ignoreNA: Boolean): Column = {
  val ignoreNALit = lit(ignoreNA)
  Column.internalFn("pandas_product", e, ignoreNALit)
}

Expand Down