From 443713320178e1428a852d27511583289c73997b Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 23 Aug 2024 17:57:35 +0800 Subject: [PATCH 1/6] nit --- python/pyspark/pandas/plot/core.py | 72 +++++++++++++++++++ python/pyspark/pandas/plot/plotly.py | 10 ++- python/pyspark/pandas/spark/functions.py | 13 ++++ .../spark/sql/api/python/PythonSQLUtils.scala | 3 + 4 files changed, 97 insertions(+), 1 deletion(-) diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index c1dc7d2dc621..91bd242b706f 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -26,6 +26,7 @@ from pyspark.sql import functions as F, Column from pyspark.sql.types import DoubleType +from pyspark.pandas.spark import functions as SF from pyspark.pandas.missing import unsupported_function from pyspark.pandas.config import get_option from pyspark.pandas.utils import name_like_string @@ -427,6 +428,15 @@ def get_fliers(colname, outliers, min_val): # Here we normalize the values by subtracting the minimum value from # each, and use absolute values. order_col = F.abs(F.col("`{}`".format(colname)) - min_val.item()) + + print() + print() + print() + print(f"min_val = {min_val.item()}") + print() + print() + print() + fliers = ( fliers_df.select(F.col("`{}`".format(colname))) .orderBy(order_col) @@ -435,6 +445,68 @@ def get_fliers(colname, outliers, min_val): .values ) + print() + print() + print() + print(fliers) + print() + print() + print() + + return fliers + + @staticmethod + def get_multicol_fliers(colnames, multicol_outliers, multicol_whiskers): + print() + print() + print(f"multicol_whiskers = {multicol_whiskers}") + print() + print() + + scols = [] + extract_colnames = [] + for i, colname in enumerate(colnames): + formated_colname = "`{}`".format(colname) + outlier_colname = "__{}_outlier".format(colname) + min_val = multicol_whiskers[colname]["min"] + print(f"min_val = {min_val}") + pair_col = F.struct( + F.abs(F.col(formated_colname) - min_val).alias("ord"), + F.col(formated_colname).alias("val"), + ) + scols.append( + SF.collect_top_k( + F.when(F.col(outlier_colname), pair_col) + .otherwise(F.lit(None)) + .alias(f"pair_{i}"), + 1001, + False, + ).alias(f"top_{i}") + ) + extract_colnames.append(f"top_{i}.val") + + results = multicol_outliers.select(scols).select(extract_colnames).first() + + print() + print() + print() + print(results) + print() + print() + print() + + fliers = {} + for i, colname in enumerate(colnames): + fliers[colname] = results[i] + + print() + print() + print() + print(fliers) + print() + print() + print() + return fliers diff --git a/python/pyspark/pandas/plot/plotly.py b/python/pyspark/pandas/plot/plotly.py index 4de313b1e831..0f0c652840dc 100644 --- a/python/pyspark/pandas/plot/plotly.py +++ b/python/pyspark/pandas/plot/plotly.py @@ -199,11 +199,19 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs): # Computes min and max values of non-outliers - the whiskers whiskers = BoxPlotBase.calc_multicol_whiskers(numeric_column_names, outliers) + fliers = None + if boxpoints: + fliers = BoxPlotBase.get_multicol_fliers(numeric_column_names, outliers, whiskers) + i = 0 for colname in numeric_column_names: col_stats = multicol_stats[colname] col_whiskers = whiskers[colname] + col_fliers = None + if fliers is not None and len(fliers[colname]) > 0: + col_fliers = fliers[colname] + fig.add_trace( go.Box( x=[i], @@ -214,7 +222,7 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs): mean=[col_stats["mean"]], lowerfence=[col_whiskers["min"]], upperfence=[col_whiskers["max"]], - y=None, # todo: support y=fliers + y=None, boxpoints=boxpoints, notched=notched, **kwargs, diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py index 8abeff655ea5..6bef3d9b87c0 100644 --- a/python/pyspark/pandas/spark/functions.py +++ b/python/pyspark/pandas/spark/functions.py @@ -174,6 +174,19 @@ def null_index(col: Column) -> Column: return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc)) +def collect_top_k(col: Column, num: int, reverse: bool) -> Column: + if is_remote(): + from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns + + return _invoke_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse)) + + else: + from pyspark import SparkContext + + sc = SparkContext._active_spark_context + return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse)) + + def make_interval(unit: str, e: Union[Column, int, float]) -> Column: unit_mapping = { "YEAR": "years", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 6b497553dcb0..c1c9af2ea427 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -149,6 +149,9 @@ private[sql] object PythonSQLUtils extends Logging { def nullIndex(e: Column): Column = Column.internalFn("null_index", e) + def collect_top_k(e: Column, num: Int, reverse: Boolean): Column = + Column.internalFn("collect_top_k", e, lit(num), lit(reverse)) + def pandasProduct(e: Column, ignoreNA: Boolean): Column = Column.internalFn("pandas_product", e, lit(ignoreNA)) From ca9994ebace1d44d8bca58aa476411bfce68931f Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 23 Aug 2024 18:11:20 +0800 Subject: [PATCH 2/6] fix type --- python/pyspark/pandas/plot/plotly.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/pandas/plot/plotly.py b/python/pyspark/pandas/plot/plotly.py index 0f0c652840dc..f5ab71bff843 100644 --- a/python/pyspark/pandas/plot/plotly.py +++ b/python/pyspark/pandas/plot/plotly.py @@ -210,7 +210,7 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs): col_fliers = None if fliers is not None and len(fliers[colname]) > 0: - col_fliers = fliers[colname] + col_fliers = [fliers[colname]] fig.add_trace( go.Box( @@ -222,7 +222,7 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs): mean=[col_stats["mean"]], lowerfence=[col_whiskers["min"]], upperfence=[col_whiskers["max"]], - y=None, + y=col_fliers, boxpoints=boxpoints, notched=notched, **kwargs, From aba608a4ef813c5a9ea568b068ca338bba080bac Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 23 Aug 2024 18:12:07 +0800 Subject: [PATCH 3/6] fix type --- python/pyspark/pandas/plot/plotly.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/pandas/plot/plotly.py b/python/pyspark/pandas/plot/plotly.py index f5ab71bff843..0afcd6d7e869 100644 --- a/python/pyspark/pandas/plot/plotly.py +++ b/python/pyspark/pandas/plot/plotly.py @@ -209,7 +209,7 @@ def plot_box(data: Union["ps.DataFrame", "ps.Series"], **kwargs): col_whiskers = whiskers[colname] col_fliers = None - if fliers is not None and len(fliers[colname]) > 0: + if fliers is not None and colname in fliers and len(fliers[colname]) > 0: col_fliers = [fliers[colname]] fig.add_trace( From ef514c4ee6d9179ca0691d224d360bda574686d3 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 26 Aug 2024 09:33:54 +0800 Subject: [PATCH 4/6] clean up --- python/pyspark/pandas/plot/core.py | 39 ------------------------------ 1 file changed, 39 deletions(-) diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index 91bd242b706f..b32ff5a66074 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -429,14 +429,6 @@ def get_fliers(colname, outliers, min_val): # each, and use absolute values. order_col = F.abs(F.col("`{}`".format(colname)) - min_val.item()) - print() - print() - print() - print(f"min_val = {min_val.item()}") - print() - print() - print() - fliers = ( fliers_df.select(F.col("`{}`".format(colname))) .orderBy(order_col) @@ -445,31 +437,16 @@ def get_fliers(colname, outliers, min_val): .values ) - print() - print() - print() - print(fliers) - print() - print() - print() - return fliers @staticmethod def get_multicol_fliers(colnames, multicol_outliers, multicol_whiskers): - print() - print() - print(f"multicol_whiskers = {multicol_whiskers}") - print() - print() - scols = [] extract_colnames = [] for i, colname in enumerate(colnames): formated_colname = "`{}`".format(colname) outlier_colname = "__{}_outlier".format(colname) min_val = multicol_whiskers[colname]["min"] - print(f"min_val = {min_val}") pair_col = F.struct( F.abs(F.col(formated_colname) - min_val).alias("ord"), F.col(formated_colname).alias("val"), @@ -487,26 +464,10 @@ def get_multicol_fliers(colnames, multicol_outliers, multicol_whiskers): results = multicol_outliers.select(scols).select(extract_colnames).first() - print() - print() - print() - print(results) - print() - print() - print() - fliers = {} for i, colname in enumerate(colnames): fliers[colname] = results[i] - print() - print() - print() - print(fliers) - print() - print() - print() - return fliers From 9548279d203ebea8319e15fc7c1962c17e79cdd2 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 26 Aug 2024 10:07:32 +0800 Subject: [PATCH 5/6] nit --- python/pyspark/pandas/plot/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index b32ff5a66074..a5b52d7d9dc6 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -448,7 +448,7 @@ def get_multicol_fliers(colnames, multicol_outliers, multicol_whiskers): outlier_colname = "__{}_outlier".format(colname) min_val = multicol_whiskers[colname]["min"] pair_col = F.struct( - F.abs(F.col(formated_colname) - min_val).alias("ord"), + F.abs(F.col(formated_colname) - F.lit(min_val)).alias("ord"), F.col(formated_colname).alias("val"), ) scols.append( From a38a9d0f8c66cb4076a2118ad61ad04c00b80cae Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 26 Aug 2024 10:13:32 +0800 Subject: [PATCH 6/6] nit --- python/pyspark/pandas/plot/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index a5b52d7d9dc6..2e188b411df1 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -428,7 +428,6 @@ def get_fliers(colname, outliers, min_val): # Here we normalize the values by subtracting the minimum value from # each, and use absolute values. order_col = F.abs(F.col("`{}`".format(colname)) - min_val.item()) - fliers = ( fliers_df.select(F.col("`{}`".format(colname))) .orderBy(order_col)