diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py index 8e3d2d8d00033..bbc089b00c133 100644 --- a/python/pyspark/sql/tests/test_group.py +++ b/python/pyspark/sql/tests/test_group.py @@ -36,11 +36,11 @@ def test_agg_func(self): data = [Row(key=1, value=10), Row(key=1, value=20), Row(key=1, value=30)] df = self.spark.createDataFrame(data) g = df.groupBy("key") - self.assertEqual(g.max("value").collect(), [Row(**{"key": 1, "max(value)": 30})]) - self.assertEqual(g.min("value").collect(), [Row(**{"key": 1, "min(value)": 10})]) - self.assertEqual(g.sum("value").collect(), [Row(**{"key": 1, "sum(value)": 60})]) - self.assertEqual(g.count().collect(), [Row(key=1, count=3)]) - self.assertEqual(g.mean("value").collect(), [Row(**{"key": 1, "avg(value)": 20.0})]) + assertDataFrameEqual(g.max("value"), [Row(**{"key": 1, "max(value)": 30})]) + assertDataFrameEqual(g.min("value"), [Row(**{"key": 1, "min(value)": 10})]) + assertDataFrameEqual(g.sum("value"), [Row(**{"key": 1, "sum(value)": 60})]) + assertDataFrameEqual(g.count(), [Row(key=1, count=3)]) + assertDataFrameEqual(g.mean("value"), [Row(**{"key": 1, "avg(value)": 20.0})]) data = [ Row(electronic="Smartphone", year=2018, sales=150000), @@ -59,7 +59,7 @@ def test_aggregator(self): df = self.df g = df.groupBy() self.assertEqual([99, 100], sorted(g.agg({"key": "max", "value": "count"}).collect()[0])) - self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + assertDataFrameEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) from pyspark.sql import functions diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index 2fca6b57decf9..683c925eefc23 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -23,6 +23,7 @@ from pyspark.sql.functions import col, lit from pyspark.sql.readwriter import DataFrameWriterV2 from pyspark.sql.types import StructType, StructField, StringType +from pyspark.testing import assertDataFrameEqual from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -34,15 +35,15 @@ def test_save_and_load(self): try: df.write.json(tmpPath) actual = self.spark.read.json(tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + assertDataFrameEqual(df, actual) schema = StructType([StructField("value", StringType(), True)]) actual = self.spark.read.json(tmpPath, schema) - self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) + assertDataFrameEqual(df.select("value"), actual) df.write.json(tmpPath, "overwrite") actual = self.spark.read.json(tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + assertDataFrameEqual(df, actual) df.write.save( format="json", @@ -53,11 +54,11 @@ def test_save_and_load(self): actual = self.spark.read.load( format="json", path=tmpPath, noUse="this options will not be used in load." ) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + assertDataFrameEqual(df, actual) with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}): actual = self.spark.read.load(path=tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + assertDataFrameEqual(df, actual) csvpath = os.path.join(tempfile.mkdtemp(), "data") df.write.option("quote", None).format("csv").save(csvpath) @@ -71,15 +72,15 @@ def test_save_and_load_builder(self): try: df.write.json(tmpPath) actual = self.spark.read.json(tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + assertDataFrameEqual(df, actual) schema = StructType([StructField("value", StringType(), True)]) actual = self.spark.read.json(tmpPath, schema) - self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect())) + assertDataFrameEqual(df.select("value"), actual) df.write.mode("overwrite").json(tmpPath) actual = self.spark.read.json(tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + assertDataFrameEqual(df, actual) df.write.mode("overwrite").options( noUse="this options will not be used in save." @@ -89,11 +90,11 @@ def test_save_and_load_builder(self): actual = self.spark.read.format("json").load( path=tmpPath, noUse="this options will not be used in load." ) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + assertDataFrameEqual(df, actual) with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}): actual = self.spark.read.load(path=tmpPath) - self.assertEqual(sorted(df.collect()), sorted(actual.collect())) + assertDataFrameEqual(df, actual) finally: shutil.rmtree(tmpPath)