From e238d7605c0432ab8103da0c6deddbb34560bd80 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 30 Oct 2023 18:11:20 +0800 Subject: [PATCH 1/2] init init --- python/pyspark/pandas/indexes/multi.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py index 9fbc608c12a4b..fe279d41648f2 100644 --- a/python/pyspark/pandas/indexes/multi.py +++ b/python/pyspark/pandas/indexes/multi.py @@ -810,7 +810,21 @@ def symmetric_difference( # type: ignore[override] sdf_self = self._psdf._internal.spark_frame.select(self._internal.index_spark_columns) sdf_other = other._psdf._internal.spark_frame.select(other._internal.index_spark_columns) - sdf_symdiff = sdf_self.union(sdf_other).subtract(sdf_self.intersect(sdf_other)) + tmp_tag_column_name = verify_temp_column_name(sdf_self, "__multi_index_tag__") + tmp_max_column_name = verify_temp_column_name(sdf_self, "__multi_index_max_tag__") + tmp_min_column_name = verify_temp_column_name(sdf_self, "__multi_index_min_tag__") + + sdf_symdiff = ( + sdf_self.withColumn(tmp_tag_column_name, F.lit(0)) + .union(sdf_other.withColumn(tmp_tag_column_name, F.lit(1))) + .groupBy(*self._internal.index_spark_column_names) + .agg( + F.min(tmp_tag_column_name).alias(tmp_min_column_name), + F.max(tmp_tag_column_name).alias(tmp_max_column_name), + ) + .where(F.col(tmp_min_column_name) == F.col(tmp_max_column_name)) + .drop(tmp_min_column_name, tmp_max_column_name) + ) if sort: sdf_symdiff = sdf_symdiff.sort(*self._internal.index_spark_column_names) From 52d164aaa01e50da01b3a4325bf8c2433cbf266d Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 13 Nov 2023 17:33:30 +0800 Subject: [PATCH 2/2] simplify --- python/pyspark/pandas/indexes/multi.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py index fe279d41648f2..62b42c1fcd02c 100644 --- a/python/pyspark/pandas/indexes/multi.py +++ b/python/pyspark/pandas/indexes/multi.py @@ -810,20 +810,17 @@ def symmetric_difference( # type: ignore[override] sdf_self = self._psdf._internal.spark_frame.select(self._internal.index_spark_columns) sdf_other = other._psdf._internal.spark_frame.select(other._internal.index_spark_columns) - tmp_tag_column_name = verify_temp_column_name(sdf_self, "__multi_index_tag__") - tmp_max_column_name = verify_temp_column_name(sdf_self, "__multi_index_max_tag__") - tmp_min_column_name = verify_temp_column_name(sdf_self, "__multi_index_min_tag__") + tmp_tag_col = verify_temp_column_name(sdf_self, "__multi_index_tag__") + tmp_max_col = verify_temp_column_name(sdf_self, "__multi_index_max_tag__") + tmp_min_col = verify_temp_column_name(sdf_self, "__multi_index_min_tag__") sdf_symdiff = ( - sdf_self.withColumn(tmp_tag_column_name, F.lit(0)) - .union(sdf_other.withColumn(tmp_tag_column_name, F.lit(1))) + sdf_self.withColumn(tmp_tag_col, F.lit(0)) + .union(sdf_other.withColumn(tmp_tag_col, F.lit(1))) .groupBy(*self._internal.index_spark_column_names) - .agg( - F.min(tmp_tag_column_name).alias(tmp_min_column_name), - F.max(tmp_tag_column_name).alias(tmp_max_column_name), - ) - .where(F.col(tmp_min_column_name) == F.col(tmp_max_column_name)) - .drop(tmp_min_column_name, tmp_max_column_name) + .agg(F.min(tmp_tag_col).alias(tmp_min_col), F.max(tmp_tag_col).alias(tmp_max_col)) + .where(F.col(tmp_min_col) == F.col(tmp_max_col)) + .select(*self._internal.index_spark_column_names) ) if sort: