From 5a4dd81c043c96225add86d24dd062443867422d Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 19 Sep 2022 12:58:48 +0800 Subject: [PATCH 1/4] refactor corr --- python/pyspark/pandas/correlation.py | 262 +++++++++++++++++++++++ python/pyspark/pandas/frame.py | 303 ++++++--------------------- 2 files changed, 323 insertions(+), 242 deletions(-) create mode 100644 python/pyspark/pandas/correlation.py diff --git a/python/pyspark/pandas/correlation.py b/python/pyspark/pandas/correlation.py new file mode 100644 index 0000000000000..75d3a857a0f2c --- /dev/null +++ b/python/pyspark/pandas/correlation.py @@ -0,0 +1,262 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List + +from pyspark.sql import DataFrame as SparkDataFrame, functions as F +from pyspark.sql.window import Window + +from pyspark.pandas.utils import verify_temp_column_name + + +CORRELATION_VALUE_1_COLUMN = "__correlation_value_1_input__" +CORRELATION_VALUE_2_COLUMN = "__correlation_value_2_input__" +CORRELATION_CORR_OUTPUT_COLUMN = "__correlation_corr_output__" +CORRELATION_COUNT_OUTPUT_COLUMN = "__correlation_count_output__" + + +def compute(sdf: SparkDataFrame, groupKeys: List[str], method: str) -> SparkDataFrame: + """ + Compute correlation per group, excluding NA/null values. + + Input PySpark Dataframe should contain column `CORRELATION_VALUE_1_COLUMN` and + column `CORRELATION_VALUE_2_COLUMN`, as well as the group columns. + + The returned PySpark Dataframe will contain the correlation column + `CORRELATION_CORR_OUTPUT_COLUMN` and the non-null count column + `CORRELATION_COUNT_OUTPUT_COLUMN`, as well as the group columns. + """ + assert len(groupKeys) > 0 + assert method in ["pearson", "spearman", "kendall"] + + sdf = sdf.select( + *[F.col(key) for key in groupKeys], + *[ + # assign both columns nulls, if some of them are null + F.when( + F.isnull(CORRELATION_VALUE_1_COLUMN) | F.isnull(CORRELATION_VALUE_2_COLUMN), + F.lit(None), + ) + .otherwise(F.col(CORRELATION_VALUE_1_COLUMN)) + .alias(CORRELATION_VALUE_1_COLUMN), + F.when( + F.isnull(CORRELATION_VALUE_1_COLUMN) | F.isnull(CORRELATION_VALUE_2_COLUMN), + F.lit(None), + ) + .otherwise(F.col(CORRELATION_VALUE_2_COLUMN)) + .alias(CORRELATION_VALUE_2_COLUMN), + ], + ) + + if method in ["pearson", "spearman"]: + # convert values to avg ranks for spearman correlation + if method == "spearman": + ROW_NUMBER_COLUMN = verify_temp_column_name( + sdf, "__correlation_spearman_row_number_temp_column__" + ) + DENSE_RANK_COLUMN = verify_temp_column_name( + sdf, "__correlation_spearman_dense_rank_temp_column__" + ) + window = Window.partitionBy(groupKeys) + + # CORRELATION_VALUE_1_COLUMN: value -> avg rank + # for example: + # values: 3, 4, 5, 7, 7, 7, 9, 9, 10 + # avg ranks: 1.0, 2.0, 3.0, 5.0, 5.0, 5.0, 7.5, 7.5, 9.0 + sdf = ( + sdf.withColumn( + ROW_NUMBER_COLUMN, + F.row_number().over( + window.orderBy(F.asc_nulls_last(CORRELATION_VALUE_1_COLUMN)) + ), + ) + # drop nulls but make sure each group contains at least one row + .where(~F.isnull(CORRELATION_VALUE_1_COLUMN) | (F.col(ROW_NUMBER_COLUMN) == 1)) + .withColumn( + DENSE_RANK_COLUMN, + F.dense_rank().over( + window.orderBy(F.asc_nulls_last(CORRELATION_VALUE_1_COLUMN)) + ), + ) + .withColumn( + CORRELATION_VALUE_1_COLUMN, + F.when(F.isnull(CORRELATION_VALUE_1_COLUMN), F.lit(None)).otherwise( + F.avg(ROW_NUMBER_COLUMN).over( + window.orderBy(F.asc(DENSE_RANK_COLUMN)).rangeBetween(0, 0) + ) + ), + ) + ) + + # CORRELATION_VALUE_2_COLUMN: value -> avg rank + sdf = ( + sdf.withColumn( + ROW_NUMBER_COLUMN, + F.row_number().over( + window.orderBy(F.asc_nulls_last(CORRELATION_VALUE_2_COLUMN)) + ), + ) + .withColumn( + DENSE_RANK_COLUMN, + F.dense_rank().over( + window.orderBy(F.asc_nulls_last(CORRELATION_VALUE_2_COLUMN)) + ), + ) + .withColumn( + CORRELATION_VALUE_2_COLUMN, + F.when(F.isnull(CORRELATION_VALUE_2_COLUMN), F.lit(None)).otherwise( + F.avg(ROW_NUMBER_COLUMN).over( + window.orderBy(F.asc(DENSE_RANK_COLUMN)).rangeBetween(0, 0) + ) + ), + ) + ) + + sdf = sdf.groupby(groupKeys).agg( + F.corr(CORRELATION_VALUE_1_COLUMN, CORRELATION_VALUE_2_COLUMN).alias( + CORRELATION_CORR_OUTPUT_COLUMN + ), + F.count( + F.when( + ~F.isnull(CORRELATION_VALUE_1_COLUMN), + 1, + ) + ).alias(CORRELATION_COUNT_OUTPUT_COLUMN), + ) + + return sdf + + else: + # kendall correlation + ROW_NUMBER_1_2_COLUMN = verify_temp_column_name( + sdf, "__correlation_kendall_row_number_1_2_temp_column__" + ) + sdf = sdf.withColumn( + ROW_NUMBER_1_2_COLUMN, + F.row_number().over( + Window.partitionBy(groupKeys).orderBy( + F.asc_nulls_last(CORRELATION_VALUE_1_COLUMN), + F.asc_nulls_last(CORRELATION_VALUE_2_COLUMN), + ) + ), + ) + + # drop nulls but make sure each group contains at least one row + sdf = sdf.where(~F.isnull(CORRELATION_VALUE_1_COLUMN) | (F.col(ROW_NUMBER_1_2_COLUMN) == 1)) + + CORRELATION_VALUE_X_COLUMN = verify_temp_column_name( + sdf, "__correlation_kendall_value_x_temp_column__" + ) + CORRELATION_VALUE_Y_COLUMN = verify_temp_column_name( + sdf, "__correlation_kendall_value_y_temp_column__" + ) + ROW_NUMBER_X_Y_COLUMN = verify_temp_column_name( + sdf, "__correlation_kendall_row_number_x_y_temp_column__" + ) + sdf2 = sdf.select( + *[F.col(key) for key in groupKeys], + *[ + F.col(CORRELATION_VALUE_1_COLUMN).alias(CORRELATION_VALUE_X_COLUMN), + F.col(CORRELATION_VALUE_2_COLUMN).alias(CORRELATION_VALUE_Y_COLUMN), + F.col(ROW_NUMBER_1_2_COLUMN).alias(ROW_NUMBER_X_Y_COLUMN), + ], + ) + + sdf = sdf.join(sdf2, groupKeys, "inner").where( + F.col(ROW_NUMBER_1_2_COLUMN) <= F.col(ROW_NUMBER_X_Y_COLUMN) + ) + + # compute P, Q, T, U in tau_b = (P - Q) / sqrt((P + Q + T) * (P + Q + U)) + # see https://github.com/scipy/scipy/blob/v1.9.1/scipy/stats/_stats_py.py#L5015-L5222 + CORRELATION_KENDALL_P_COLUMN = verify_temp_column_name( + sdf, "__correlation_kendall_tau_b_p_temp_column__" + ) + CORRELATION_KENDALL_Q_COLUMN = verify_temp_column_name( + sdf, "__correlation_kendall_tau_b_q_temp_column__" + ) + CORRELATION_KENDALL_T_COLUMN = verify_temp_column_name( + sdf, "__correlation_kendall_tau_b_t_temp_column__" + ) + CORRELATION_KENDALL_U_COLUMN = verify_temp_column_name( + sdf, "__correlation_kendall_tau_b_u_temp_column__" + ) + + pair_cond = ~F.isnull(CORRELATION_VALUE_1_COLUMN) & ( + F.col(ROW_NUMBER_1_2_COLUMN) < F.col(ROW_NUMBER_X_Y_COLUMN) + ) + + p_cond = ( + (F.col(CORRELATION_VALUE_1_COLUMN) < F.col(CORRELATION_VALUE_X_COLUMN)) + & (F.col(CORRELATION_VALUE_2_COLUMN) < F.col(CORRELATION_VALUE_Y_COLUMN)) + ) | ( + (F.col(CORRELATION_VALUE_1_COLUMN) > F.col(CORRELATION_VALUE_X_COLUMN)) + & (F.col(CORRELATION_VALUE_2_COLUMN) > F.col(CORRELATION_VALUE_Y_COLUMN)) + ) + q_cond = ( + (F.col(CORRELATION_VALUE_1_COLUMN) < F.col(CORRELATION_VALUE_X_COLUMN)) + & (F.col(CORRELATION_VALUE_2_COLUMN) > F.col(CORRELATION_VALUE_Y_COLUMN)) + ) | ( + (F.col(CORRELATION_VALUE_1_COLUMN) > F.col(CORRELATION_VALUE_X_COLUMN)) + & (F.col(CORRELATION_VALUE_2_COLUMN) < F.col(CORRELATION_VALUE_Y_COLUMN)) + ) + t_cond = (F.col(CORRELATION_VALUE_1_COLUMN) == F.col(CORRELATION_VALUE_X_COLUMN)) & ( + F.col(CORRELATION_VALUE_2_COLUMN) != F.col(CORRELATION_VALUE_Y_COLUMN) + ) + u_cond = (F.col(CORRELATION_VALUE_1_COLUMN) != F.col(CORRELATION_VALUE_X_COLUMN)) & ( + F.col(CORRELATION_VALUE_2_COLUMN) == F.col(CORRELATION_VALUE_Y_COLUMN) + ) + + sdf = ( + sdf.groupby(groupKeys) + .agg( + F.count(F.when(pair_cond & p_cond, 1)).alias(CORRELATION_KENDALL_P_COLUMN), + F.count(F.when(pair_cond & q_cond, 1)).alias(CORRELATION_KENDALL_Q_COLUMN), + F.count(F.when(pair_cond & t_cond, 1)).alias(CORRELATION_KENDALL_T_COLUMN), + F.count(F.when(pair_cond & u_cond, 1)).alias(CORRELATION_KENDALL_U_COLUMN), + F.max( + F.when( + ~F.isnull(CORRELATION_VALUE_1_COLUMN), F.col(ROW_NUMBER_X_Y_COLUMN) + ).otherwise(F.lit(0)) + ).alias(CORRELATION_COUNT_OUTPUT_COLUMN), + ) + .withColumn( + CORRELATION_CORR_OUTPUT_COLUMN, + (F.col(CORRELATION_KENDALL_P_COLUMN) - F.col(CORRELATION_KENDALL_Q_COLUMN)) + / F.sqrt( + ( + ( + F.col(CORRELATION_KENDALL_P_COLUMN) + + F.col(CORRELATION_KENDALL_Q_COLUMN) + + (F.col(CORRELATION_KENDALL_T_COLUMN)) + ) + ) + * ( + ( + F.col(CORRELATION_KENDALL_P_COLUMN) + + F.col(CORRELATION_KENDALL_Q_COLUMN) + + (F.col(CORRELATION_KENDALL_U_COLUMN)) + ) + ) + ), + ) + ) + + sdf = sdf.select( + *[F.col(key) for key in groupKeys], + *[CORRELATION_CORR_OUTPUT_COLUMN, CORRELATION_COUNT_OUTPUT_COLUMN], + ) + return sdf diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index e2b70caf5d7e2..e1e5125476c09 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -88,6 +88,13 @@ from pyspark.pandas._typing import Axis, DataFrameOrSeries, Dtype, Label, Name, Scalar, T from pyspark.pandas.accessors import PandasOnSparkFrameMethods from pyspark.pandas.config import option_context, get_option +from pyspark.pandas.correlation import ( + compute, + CORRELATION_VALUE_1_COLUMN, + CORRELATION_VALUE_2_COLUMN, + CORRELATION_CORR_OUTPUT_COLUMN, + CORRELATION_COUNT_OUTPUT_COLUMN, +) from pyspark.pandas.spark import functions as SF from pyspark.pandas.spark.accessors import SparkFrameMethods, CachedSparkFrameMethods from pyspark.pandas.utils import ( @@ -1489,10 +1496,9 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D num_scols = len(numeric_scols) sdf = internal.spark_frame - tmp_index_1_col_name = verify_temp_column_name(sdf, "__tmp_index_1_col__") - tmp_index_2_col_name = verify_temp_column_name(sdf, "__tmp_index_2_col__") - tmp_value_1_col_name = verify_temp_column_name(sdf, "__tmp_value_1_col__") - tmp_value_2_col_name = verify_temp_column_name(sdf, "__tmp_value_2_col__") + index_1_col_name = verify_temp_column_name(sdf, "__corr_index_1_temp_column__") + index_2_col_name = verify_temp_column_name(sdf, "__corr_index_2_temp_column__") + tuple_col_name = verify_temp_column_name(sdf, "__corr_tuple_temp_column__") # simple dataset # +---+---+----+ @@ -1507,10 +1513,10 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D for j in range(i, num_scols): pair_scols.append( F.struct( - F.lit(i).alias(tmp_index_1_col_name), - F.lit(j).alias(tmp_index_2_col_name), - numeric_scols[i].alias(tmp_value_1_col_name), - numeric_scols[j].alias(tmp_value_2_col_name), + F.lit(i).alias(index_1_col_name), + F.lit(j).alias(index_2_col_name), + numeric_scols[i].alias(CORRELATION_VALUE_1_COLUMN), + numeric_scols[j].alias(CORRELATION_VALUE_2_COLUMN), ) ) @@ -1530,209 +1536,32 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D # | 1| 2| null| null| # | 2| 2| null| null| # +-------------------+-------------------+-------------------+-------------------+ - tmp_tuple_col_name = verify_temp_column_name(sdf, "__tmp_tuple_col__") - null_cond = F.isnull(F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}")) | F.isnull( - F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}") - ) - sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tmp_tuple_col_name)).select( - F.col(f"{tmp_tuple_col_name}.{tmp_index_1_col_name}").alias(tmp_index_1_col_name), - F.col(f"{tmp_tuple_col_name}.{tmp_index_2_col_name}").alias(tmp_index_2_col_name), - F.when(null_cond, F.lit(None)) - .otherwise(F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}")) - .alias(tmp_value_1_col_name), - F.when(null_cond, F.lit(None)) - .otherwise(F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}")) - .alias(tmp_value_2_col_name), - ) - not_null_cond = ( - F.col(tmp_value_1_col_name).isNotNull() & F.col(tmp_value_2_col_name).isNotNull() + sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tuple_col_name)).select( + F.col(f"{tuple_col_name}.{index_1_col_name}").alias(index_1_col_name), + F.col(f"{tuple_col_name}.{index_2_col_name}").alias(index_2_col_name), + F.col(f"{tuple_col_name}.{CORRELATION_VALUE_1_COLUMN}").alias( + CORRELATION_VALUE_1_COLUMN + ), + F.col(f"{tuple_col_name}.{CORRELATION_VALUE_2_COLUMN}").alias( + CORRELATION_VALUE_2_COLUMN + ), ) - tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__") - tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_corr_col__") - if method in ["pearson", "spearman"]: - # convert values to avg ranks for spearman correlation - if method == "spearman": - tmp_row_number_col_name = verify_temp_column_name(sdf, "__tmp_row_number_col__") - tmp_dense_rank_col_name = verify_temp_column_name(sdf, "__tmp_dense_rank_col__") - window = Window.partitionBy(tmp_index_1_col_name, tmp_index_2_col_name) - - # tmp_value_1_col_name: value -> avg rank - # for example: - # values: 3, 4, 5, 7, 7, 7, 9, 9, 10 - # avg ranks: 1.0, 2.0, 3.0, 5.0, 5.0, 5.0, 7.5, 7.5, 9.0 - sdf = ( - sdf.withColumn( - tmp_row_number_col_name, - F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))), - ) - .withColumn( - tmp_dense_rank_col_name, - F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_1_col_name))), - ) - .withColumn( - tmp_value_1_col_name, - F.when(F.isnull(F.col(tmp_value_1_col_name)), F.lit(None)).otherwise( - F.avg(tmp_row_number_col_name).over( - window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0) - ) - ), - ) - ) - - # tmp_value_2_col_name: value -> avg rank - sdf = ( - sdf.withColumn( - tmp_row_number_col_name, - F.row_number().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))), - ) - .withColumn( - tmp_dense_rank_col_name, - F.dense_rank().over(window.orderBy(F.asc_nulls_last(tmp_value_2_col_name))), - ) - .withColumn( - tmp_value_2_col_name, - F.when(F.isnull(F.col(tmp_value_2_col_name)), F.lit(None)).otherwise( - F.avg(tmp_row_number_col_name).over( - window.orderBy(F.asc(tmp_dense_rank_col_name)).rangeBetween(0, 0) - ) - ), - ) - ) - - sdf = sdf.select( - tmp_index_1_col_name, - tmp_index_2_col_name, - tmp_value_1_col_name, - tmp_value_2_col_name, - ) - - # +-------------------+-------------------+----------------+-----------------+ - # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_corr_col__|__tmp_count_col__| - # +-------------------+-------------------+----------------+-----------------+ - # | 2| 2| null| 1| - # | 1| 2| null| 1| - # | 1| 1| 1.0| 2| - # | 0| 0| 1.0| 2| - # | 0| 1| -1.0| 2| - # | 0| 2| null| 1| - # +-------------------+-------------------+----------------+-----------------+ - sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg( - F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name), - F.count( - F.when( - not_null_cond, - 1, - ) - ).alias(tmp_count_col_name), - ) - - else: - # kendall correlation - tmp_row_number_12_col_name = verify_temp_column_name(sdf, "__tmp_row_number_12_col__") + sdf = compute(sdf=sdf, groupKeys=[index_1_col_name, index_2_col_name], method=method) + if method == "kendall": sdf = sdf.withColumn( - tmp_row_number_12_col_name, - F.row_number().over( - Window.partitionBy(tmp_index_1_col_name, tmp_index_2_col_name).orderBy( - F.asc_nulls_last(tmp_value_1_col_name), - F.asc_nulls_last(tmp_value_2_col_name), - ) + CORRELATION_CORR_OUTPUT_COLUMN, + F.when(F.col(index_1_col_name) == F.col(index_2_col_name), F.lit(1.0)).otherwise( + F.col(CORRELATION_CORR_OUTPUT_COLUMN) ), ) - # drop nulls but make sure each partition contains at least one row - sdf = sdf.where(not_null_cond | (F.col(tmp_row_number_12_col_name) == 1)) - - tmp_value_x_col_name = verify_temp_column_name(sdf, "__tmp_value_x_col__") - tmp_value_y_col_name = verify_temp_column_name(sdf, "__tmp_value_y_col__") - tmp_row_number_xy_col_name = verify_temp_column_name(sdf, "__tmp_row_number_xy_col__") - sdf2 = sdf.select( - F.col(tmp_index_1_col_name), - F.col(tmp_index_2_col_name), - F.col(tmp_value_1_col_name).alias(tmp_value_x_col_name), - F.col(tmp_value_2_col_name).alias(tmp_value_y_col_name), - F.col(tmp_row_number_12_col_name).alias(tmp_row_number_xy_col_name), - ) - - sdf = sdf.join(sdf2, [tmp_index_1_col_name, tmp_index_2_col_name], "inner").where( - F.col(tmp_row_number_12_col_name) <= F.col(tmp_row_number_xy_col_name) - ) - - # compute P, Q, T, U in tau_b = (P - Q) / sqrt((P + Q + T) * (P + Q + U)) - # see https://github.com/scipy/scipy/blob/v1.9.1/scipy/stats/_stats_py.py#L5015-L5222 - tmp_tau_b_p_col_name = verify_temp_column_name(sdf, "__tmp_tau_b_p_col__") - tmp_tau_b_q_col_name = verify_temp_column_name(sdf, "__tmp_tau_b_q_col__") - tmp_tau_b_t_col_name = verify_temp_column_name(sdf, "__tmp_tau_b_t_col__") - tmp_tau_b_u_col_name = verify_temp_column_name(sdf, "__tmp_tau_b_u_col__") - - pair_cond = not_null_cond & ( - F.col(tmp_row_number_12_col_name) < F.col(tmp_row_number_xy_col_name) - ) - - p_cond = ( - (F.col(tmp_value_1_col_name) < F.col(tmp_value_x_col_name)) - & (F.col(tmp_value_2_col_name) < F.col(tmp_value_y_col_name)) - ) | ( - (F.col(tmp_value_1_col_name) > F.col(tmp_value_x_col_name)) - & (F.col(tmp_value_2_col_name) > F.col(tmp_value_y_col_name)) - ) - q_cond = ( - (F.col(tmp_value_1_col_name) < F.col(tmp_value_x_col_name)) - & (F.col(tmp_value_2_col_name) > F.col(tmp_value_y_col_name)) - ) | ( - (F.col(tmp_value_1_col_name) > F.col(tmp_value_x_col_name)) - & (F.col(tmp_value_2_col_name) < F.col(tmp_value_y_col_name)) - ) - t_cond = (F.col(tmp_value_1_col_name) == F.col(tmp_value_x_col_name)) & ( - F.col(tmp_value_2_col_name) != F.col(tmp_value_y_col_name) - ) - u_cond = (F.col(tmp_value_1_col_name) != F.col(tmp_value_x_col_name)) & ( - F.col(tmp_value_2_col_name) == F.col(tmp_value_y_col_name) - ) - - sdf = ( - sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name) - .agg( - F.count(F.when(pair_cond & p_cond, 1)).alias(tmp_tau_b_p_col_name), - F.count(F.when(pair_cond & q_cond, 1)).alias(tmp_tau_b_q_col_name), - F.count(F.when(pair_cond & t_cond, 1)).alias(tmp_tau_b_t_col_name), - F.count(F.when(pair_cond & u_cond, 1)).alias(tmp_tau_b_u_col_name), - F.max( - F.when(not_null_cond, F.col(tmp_row_number_xy_col_name)).otherwise(F.lit(0)) - ).alias(tmp_count_col_name), - ) - .withColumn( - tmp_corr_col_name, - F.when( - F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name), F.lit(1.0) - ).otherwise( - (F.col(tmp_tau_b_p_col_name) - F.col(tmp_tau_b_q_col_name)) - / F.sqrt( - ( - ( - F.col(tmp_tau_b_p_col_name) - + F.col(tmp_tau_b_q_col_name) - + (F.col(tmp_tau_b_t_col_name)) - ) - ) - * ( - ( - F.col(tmp_tau_b_p_col_name) - + F.col(tmp_tau_b_q_col_name) - + (F.col(tmp_tau_b_u_col_name)) - ) - ) - ) - ), - ) - ) - - sdf = sdf.select( - F.col(tmp_index_1_col_name), - F.col(tmp_index_2_col_name), - F.col(tmp_corr_col_name), - F.col(tmp_count_col_name), - ) + sdf = sdf.withColumn( + CORRELATION_CORR_OUTPUT_COLUMN, + F.when(F.col(CORRELATION_COUNT_OUTPUT_COLUMN) < min_periods, F.lit(None)).otherwise( + F.col(CORRELATION_CORR_OUTPUT_COLUMN) + ), + ) # +-------------------+-------------------+----------------+ # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_corr_col__| @@ -1747,31 +1576,23 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D # | 0| 2| null| # | 2| 0| null| # +-------------------+-------------------+----------------+ - sdf = ( - sdf.withColumn( - tmp_corr_col_name, + + sdf = sdf.withColumn( + tuple_col_name, + F.explode( F.when( - F.col(tmp_count_col_name) >= min_periods, F.col(tmp_corr_col_name) - ).otherwise(F.lit(None)), - ) - .withColumn( - tmp_tuple_col_name, - F.explode( - F.when( - F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name), - F.lit([0]), - ).otherwise(F.lit([0, 1])) - ), - ) - .select( - F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_1_col_name)) - .otherwise(F.col(tmp_index_2_col_name)) - .alias(tmp_index_1_col_name), - F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_2_col_name)) - .otherwise(F.col(tmp_index_1_col_name)) - .alias(tmp_index_2_col_name), - F.col(tmp_corr_col_name), - ) + F.col(index_1_col_name) == F.col(index_2_col_name), + F.lit([0]), + ).otherwise(F.lit([0, 1])) + ), + ).select( + F.when(F.col(tuple_col_name) == 0, F.col(index_1_col_name)) + .otherwise(F.col(index_2_col_name)) + .alias(index_1_col_name), + F.when(F.col(tuple_col_name) == 0, F.col(index_2_col_name)) + .otherwise(F.col(index_1_col_name)) + .alias(index_2_col_name), + F.col(CORRELATION_CORR_OUTPUT_COLUMN), ) # +-------------------+--------------------+ @@ -1781,23 +1602,23 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D # | 1|[{0, -1.0}, {1, 1...| # | 2|[{0, null}, {1, n...| # +-------------------+--------------------+ - tmp_array_col_name = verify_temp_column_name(sdf, "__tmp_array_col__") + array_col_name = verify_temp_column_name(sdf, "__corr_array_temp_column__") sdf = ( - sdf.groupby(tmp_index_1_col_name) + sdf.groupby(index_1_col_name) .agg( F.array_sort( - F.collect_list(F.struct(F.col(tmp_index_2_col_name), F.col(tmp_corr_col_name))) - ).alias(tmp_array_col_name) + F.collect_list( + F.struct(F.col(index_2_col_name), F.col(CORRELATION_CORR_OUTPUT_COLUMN)) + ) + ).alias(array_col_name) ) - .orderBy(tmp_index_1_col_name) + .orderBy(index_1_col_name) ) for i in range(0, num_scols): - sdf = sdf.withColumn( - tmp_tuple_col_name, F.get(F.col(tmp_array_col_name), i) - ).withColumn( + sdf = sdf.withColumn(tuple_col_name, F.get(F.col(array_col_name), i)).withColumn( numeric_col_names[i], - F.col(f"{tmp_tuple_col_name}.{tmp_corr_col_name}"), + F.col(f"{tuple_col_name}.{CORRELATION_CORR_OUTPUT_COLUMN}"), ) index_col_names: List[str] = [] @@ -1805,14 +1626,12 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D for level in range(0, internal.column_labels_level): index_col_name = SPARK_INDEX_NAME_FORMAT(level) indices = [label[level] for label in numeric_labels] - sdf = sdf.withColumn( - index_col_name, F.get(F.lit(indices), F.col(tmp_index_1_col_name)) - ) + sdf = sdf.withColumn(index_col_name, F.get(F.lit(indices), F.col(index_1_col_name))) index_col_names.append(index_col_name) else: sdf = sdf.withColumn( SPARK_DEFAULT_INDEX_NAME, - F.get(F.lit(numeric_col_names), F.col(tmp_index_1_col_name)), + F.get(F.lit(numeric_col_names), F.col(index_1_col_name)), ) index_col_names = [SPARK_DEFAULT_INDEX_NAME] From b46f23de3ff4149b37c7a8d75ae1d83343858047 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 19 Sep 2022 14:46:43 +0800 Subject: [PATCH 2/4] refactor corrwith --- python/pyspark/pandas/frame.py | 128 ++++++++++++------ python/pyspark/pandas/tests/test_dataframe.py | 18 +-- 2 files changed, 96 insertions(+), 50 deletions(-) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index e1e5125476c09..8cde566ce9bcf 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -1648,9 +1648,8 @@ def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "D ) ) - # TODO: add axis parameter and support more methods def corrwith( - self, other: DataFrameOrSeries, drop: bool = False, method: str = "pearson" + self, other: DataFrameOrSeries, axis: Axis = 0, drop: bool = False, method: str = "pearson" ) -> "Series": """ Compute pairwise correlation. @@ -1690,10 +1689,10 @@ def corrwith( ... "A":[1, 5, 7, 8], ... "X":[5, 8, 4, 3], ... "C":[10, 4, 9, 3]}) - >>> df1.corrwith(df1[["X", "C"]]) - X 1.0 - C 1.0 + >>> df1.corrwith(df1[["X", "C"]]).sort_index() A NaN + C 1.0 + X 1.0 dtype: float64 >>> df2 = ps.DataFrame({ @@ -1702,15 +1701,31 @@ def corrwith( ... "C":[4, 3, 8, 5]}) >>> with ps.option_context("compute.ops_on_diff_frames", True): - ... df1.corrwith(df2) + ... df1.corrwith(df2).sort_index() A -0.041703 + B NaN C 0.395437 X NaN + dtype: float64 + + >>> with ps.option_context("compute.ops_on_diff_frames", True): + ... df1.corrwith(df2, method="kendall").sort_index() + A 0.0 + B NaN + C 0.0 + X NaN + dtype: float64 + + >>> with ps.option_context("compute.ops_on_diff_frames", True): + ... df1.corrwith(df2, method="spearman").sort_index() + A -0.041703 B NaN + C 0.395437 + X NaN dtype: float64 >>> with ps.option_context("compute.ops_on_diff_frames", True): - ... df2.corrwith(df1.X) + ... df2.corrwith(df1.X).sort_index() A -0.597614 B -0.151186 C -0.642857 @@ -1718,8 +1733,11 @@ def corrwith( """ from pyspark.pandas.series import Series, first_series - if (method is not None) and (method not in ["pearson"]): - raise NotImplementedError("corrwith currently works only for method='pearson'") + axis = validate_axis(axis) + if axis != 0: + raise NotImplementedError("corrwith currently only works for axis=0") + if method not in ["pearson", "spearman", "kendall"]: + raise ValueError(f"Invalid method {method}") if not isinstance(other, (DataFrame, Series)): raise TypeError("unsupported type: {}".format(type(other).__name__)) @@ -1734,6 +1752,10 @@ def corrwith( this = combined["this"] that = combined["that"] + sdf = combined._internal.spark_frame + index_col_name = verify_temp_column_name(sdf, "__corrwith_index_temp_column__") + tuple_col_name = verify_temp_column_name(sdf, "__corrwith_tuple_temp_column__") + this_numeric_column_labels: List[Label] = [] for column_label in this._internal.column_labels: if isinstance(this._internal.spark_type_for(column_label), (NumericType, BooleanType)): @@ -1746,15 +1768,19 @@ def corrwith( intersect_numeric_column_labels: List[Label] = [] diff_numeric_column_labels: List[Label] = [] - corr_scols = [] + pair_scols: List[Column] = [] if right_is_series: intersect_numeric_column_labels = this_numeric_column_labels - that_scol = that._internal.spark_column_for(that_numeric_column_labels[0]) + that_scol = that._internal.spark_column_for(that_numeric_column_labels[0]).cast( + "double" + ) for numeric_column_label in intersect_numeric_column_labels: - this_scol = this._internal.spark_column_for(numeric_column_label) - corr_scols.append( - F.corr(this_scol.cast("double"), that_scol.cast("double")).alias( - name_like_string(numeric_column_label) + this_scol = this._internal.spark_column_for(numeric_column_label).cast("double") + pair_scols.append( + F.struct( + F.lit(name_like_string(numeric_column_label)).alias(index_col_name), + this_scol.alias(CORRELATION_VALUE_1_COLUMN), + that_scol.alias(CORRELATION_VALUE_2_COLUMN), ) ) else: @@ -1767,37 +1793,57 @@ def corrwith( if numeric_column_label not in this_numeric_column_labels: diff_numeric_column_labels.append(numeric_column_label) for numeric_column_label in intersect_numeric_column_labels: - this_scol = this._internal.spark_column_for(numeric_column_label) - that_scol = that._internal.spark_column_for(numeric_column_label) - corr_scols.append( - F.corr(this_scol.cast("double"), that_scol.cast("double")).alias( - name_like_string(numeric_column_label) + this_scol = this._internal.spark_column_for(numeric_column_label).cast("double") + that_scol = that._internal.spark_column_for(numeric_column_label).cast("double") + pair_scols.append( + F.struct( + F.lit(name_like_string(numeric_column_label)).alias(index_col_name), + this_scol.alias(CORRELATION_VALUE_1_COLUMN), + that_scol.alias(CORRELATION_VALUE_2_COLUMN), ) ) - corr_labels: List[Label] = intersect_numeric_column_labels - if not drop: - for numeric_column_label in diff_numeric_column_labels: - corr_scols.append( - F.lit(None).cast("double").alias(name_like_string(numeric_column_label)) - ) - corr_labels.append(numeric_column_label) - - sdf = combined._internal.spark_frame.select( - *[F.lit(None).cast(StringType()).alias(SPARK_DEFAULT_INDEX_NAME)], *corr_scols - ).limit( - 1 - ) # limit(1) to avoid returning more than 1 row when intersection is empty + if len(pair_scols) > 0: + sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tuple_col_name)).select( + F.col(f"{tuple_col_name}.{index_col_name}").alias(index_col_name), + F.col(f"{tuple_col_name}.{CORRELATION_VALUE_1_COLUMN}").alias( + CORRELATION_VALUE_1_COLUMN + ), + F.col(f"{tuple_col_name}.{CORRELATION_VALUE_2_COLUMN}").alias( + CORRELATION_VALUE_2_COLUMN + ), + ) - # The data is expected to be small so it's fine to transpose/use default index. - with ps.option_context("compute.max_rows", 1): - internal = InternalFrame( - spark_frame=sdf, - index_spark_columns=[scol_for(sdf, SPARK_DEFAULT_INDEX_NAME)], - column_labels=corr_labels, - column_label_names=self._internal.column_label_names, + sdf = compute(sdf=sdf, groupKeys=[index_col_name], method=method).select( + index_col_name, CORRELATION_CORR_OUTPUT_COLUMN ) - return first_series(DataFrame(internal).transpose()) + + else: + sdf = self._internal.spark_frame.select( + F.lit(None).cast("string").alias(index_col_name), + F.lit(None).cast("double").alias(CORRELATION_CORR_OUTPUT_COLUMN), + ).limit(0) + + if not drop and len(diff_numeric_column_labels) > 0: + sdf2 = sdf.sparkSession.createDataFrame( + [name_like_string(label) for label in diff_numeric_column_labels], StringType() + ).select(F.col("value").alias(index_col_name)) + sdf = sdf.unionByName(sdf2, allowMissingColumns=True) + + sdf = sdf.withColumn( + NATURAL_ORDER_COLUMN_NAME, + F.monotonically_increasing_id(), + ) + + internal = InternalFrame( + spark_frame=sdf, + index_spark_columns=[scol_for(sdf, index_col_name)], + column_labels=[(CORRELATION_CORR_OUTPUT_COLUMN,)], + column_label_names=self._internal.column_label_names, + ) + sser = first_series(DataFrame(internal)) + sser.name = None + return sser def iteritems(self) -> Iterator[Tuple[Name, "Series"]]: """ diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py index 48919514459f3..5da0974c9063d 100644 --- a/python/pyspark/pandas/tests/test_dataframe.py +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -6062,13 +6062,12 @@ def test_corrwith(self): self._test_corrwith((df1 + 1), (df2.C + 2)) self._test_corrwith((df1 + 1), (df3.B + 2)) - with self.assertRaisesRegex( - NotImplementedError, "corrwith currently works only for method='pearson'" - ): - df1.corrwith(df2, method="kendall") - with self.assertRaisesRegex(TypeError, "unsupported type"): df1.corrwith(123) + with self.assertRaisesRegex(NotImplementedError, "only works for axis=0"): + df1.corrwith(df1.A, axis=1) + with self.assertRaisesRegex(ValueError, "Invalid method"): + df1.corrwith(df1.A, method="cov") df_bool = ps.DataFrame({"A": [True, True, False, False], "B": [True, False, False, True]}) self._test_corrwith(df_bool, df_bool.A) @@ -6077,10 +6076,11 @@ def test_corrwith(self): def _test_corrwith(self, psdf, psobj): pdf = psdf.to_pandas() pobj = psobj.to_pandas() - for drop in [True, False]: - p_corr = pdf.corrwith(pobj, drop=drop) - ps_corr = psdf.corrwith(psobj, drop=drop) - self.assert_eq(p_corr.sort_index(), ps_corr.sort_index(), almost=True) + for method in ["pearson", "spearman", "kendall"]: + for drop in [True, False]: + p_corr = pdf.corrwith(pobj, drop=drop, method=method) + ps_corr = psdf.corrwith(psobj, drop=drop, method=method) + self.assert_eq(p_corr.sort_index(), ps_corr.sort_index(), almost=True) def test_iteritems(self): pdf = pd.DataFrame( From 8790b46b93b1630636560ba1212eed1733c57fb2 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 19 Sep 2022 16:17:11 +0800 Subject: [PATCH 3/4] fix test --- python/pyspark/pandas/frame.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 8cde566ce9bcf..78b97d9f4651a 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -1717,11 +1717,10 @@ def corrwith( dtype: float64 >>> with ps.option_context("compute.ops_on_diff_frames", True): - ... df1.corrwith(df2, method="spearman").sort_index() - A -0.041703 - B NaN - C 0.395437 - X NaN + ... df1.corrwith(df2.B, method="spearman").sort_index() + A -0.4 + C 0.8 + X -0.2 dtype: float64 >>> with ps.option_context("compute.ops_on_diff_frames", True): From 2ea62a84fd3b2374040d8f9ad7d90ed507193564 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 20 Sep 2022 10:27:05 +0800 Subject: [PATCH 4/4] update docs --- python/pyspark/pandas/frame.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 78b97d9f4651a..014fc175315a0 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -1665,14 +1665,14 @@ def corrwith( ---------- other : DataFrame, Series Object with which to compute correlations. - + axis : int, default 0 or 'index' + Can only be set to 0 at the moment. drop : bool, default False Drop missing indices from result. - - method : str, default 'pearson' - Method of correlation, one of: - + method : {'pearson', 'spearman', 'kendall'} * pearson : standard correlation coefficient + * spearman : Spearman rank correlation + * kendall : Kendall Tau correlation coefficient Returns ------- @@ -1824,9 +1824,15 @@ def corrwith( ).limit(0) if not drop and len(diff_numeric_column_labels) > 0: - sdf2 = sdf.sparkSession.createDataFrame( - [name_like_string(label) for label in diff_numeric_column_labels], StringType() - ).select(F.col("value").alias(index_col_name)) + sdf2 = ( + self._internal.spark_frame.select( + F.lit([name_like_string(label) for label in diff_numeric_column_labels]).alias( + index_col_name + ) + ) + .limit(1) + .select(F.explode(index_col_name).alias(index_col_name)) + ) sdf = sdf.unionByName(sdf2, allowMissingColumns=True) sdf = sdf.withColumn(