Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions python/pyspark/sql/tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ def test_agg_func(self):
data = [Row(key=1, value=10), Row(key=1, value=20), Row(key=1, value=30)]
df = self.spark.createDataFrame(data)
g = df.groupBy("key")
- self.assertEqual(g.max("value").collect(), [Row(**{"key": 1, "max(value)": 30})])
- self.assertEqual(g.min("value").collect(), [Row(**{"key": 1, "min(value)": 10})])
- self.assertEqual(g.sum("value").collect(), [Row(**{"key": 1, "sum(value)": 60})])
- self.assertEqual(g.count().collect(), [Row(key=1, count=3)])
- self.assertEqual(g.mean("value").collect(), [Row(**{"key": 1, "avg(value)": 20.0})])
+ assertDataFrameEqual(g.max("value"), [Row(**{"key": 1, "max(value)": 30})])
+ assertDataFrameEqual(g.min("value"), [Row(**{"key": 1, "min(value)": 10})])
+ assertDataFrameEqual(g.sum("value"), [Row(**{"key": 1, "sum(value)": 60})])
+ assertDataFrameEqual(g.count(), [Row(key=1, count=3)])
+ assertDataFrameEqual(g.mean("value"), [Row(**{"key": 1, "avg(value)": 20.0})])

data = [
Row(electronic="Smartphone", year=2018, sales=150000),
Expand All @@ -59,7 +59,7 @@ def test_aggregator(self):
df = self.df
g = df.groupBy()
self.assertEqual([99, 100], sorted(g.agg({"key": "max", "value": "count"}).collect()[0]))
- self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+ assertDataFrameEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())

from pyspark.sql import functions

Expand Down
21 changes: 11 additions & 10 deletions python/pyspark/sql/tests/test_readwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pyspark.sql.functions import col, lit
from pyspark.sql.readwriter import DataFrameWriterV2
from pyspark.sql.types import StructType, StructField, StringType
+ from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import ReusedSQLTestCase


Expand All @@ -34,15 +35,15 @@ def test_save_and_load(self):
try:
df.write.json(tmpPath)
actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)

schema = StructType([StructField("value", StringType(), True)])
actual = self.spark.read.json(tmpPath, schema)
- self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df.select("value"), actual)

df.write.json(tmpPath, "overwrite")
actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)

df.write.save(
format="json",
Expand All @@ -53,11 +54,11 @@ def test_save_and_load(self):
actual = self.spark.read.load(
format="json", path=tmpPath, noUse="this options will not be used in load."
)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)

with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
actual = self.spark.read.load(path=tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)

csvpath = os.path.join(tempfile.mkdtemp(), "data")
df.write.option("quote", None).format("csv").save(csvpath)
Expand All @@ -71,15 +72,15 @@ def test_save_and_load_builder(self):
try:
df.write.json(tmpPath)
actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)

schema = StructType([StructField("value", StringType(), True)])
actual = self.spark.read.json(tmpPath, schema)
- self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df.select("value"), actual)

df.write.mode("overwrite").json(tmpPath)
actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)

df.write.mode("overwrite").options(
noUse="this options will not be used in save."
Expand All @@ -89,11 +90,11 @@ def test_save_and_load_builder(self):
actual = self.spark.read.format("json").load(
path=tmpPath, noUse="this options will not be used in load."
)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)

with self.sql_conf({"spark.sql.sources.default": "org.apache.spark.sql.json"}):
actual = self.spark.read.load(path=tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)
finally:
shutil.rmtree(tmpPath)

Expand Down