[SPARK-23691][PYTHON] Use sql_conf util in PySpark tests where possible
## What changes were proposed in this pull request?

d6632d1 added a useful util

```python
@contextmanager
def sql_conf(self, pairs):
    ...
```

to allow configuration set/unset within a block:

```python
with self.sql_conf({"spark.blah.blah.blah": "blah"}):
    # test code
```
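
For reference, a minimal sketch of what such a util could look like (an illustration, not necessarily the exact code added in d6632d1; it assumes a test class that exposes a `self.spark` session, as the SQL test cases here do):

```python
from contextlib import contextmanager

@contextmanager
def sql_conf(self, pairs):
    """Temporarily sets SQL configurations and restores the originals on exit (sketch)."""
    assert isinstance(pairs, dict), "pairs should be a dictionary."

    keys = pairs.keys()
    new_values = pairs.values()
    # Remember the current values (None if the key was not set).
    old_values = [self.spark.conf.get(key, None) for key in keys]
    for key, new_value in zip(keys, new_values):
        self.spark.conf.set(key, new_value)
    try:
        yield
    finally:
        # Restore or unset each configuration, even if the test body raised.
        for key, old_value in zip(keys, old_values):
            if old_value is None:
                self.spark.conf.unset(key)
            else:
                self.spark.conf.set(key, old_value)
```

Because the restore happens in a `finally` block, the original values are put back (or unset) even when the test body raises, which is what removes the need for the repeated try/finally boilerplate in the individual tests below.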

This PR proposes to use this util where possible in PySpark tests.

Note that there already appear to be a few places in the unittest classes that change configurations for tests without restoring the original values afterwards, which can affect other tests.

## How was this patch tested?

Manually tested via:

```
./run-tests --modules=pyspark-sql --python-executables=python2
./run-tests --modules=pyspark-sql --python-executables=python3
```

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #20830 from HyukjinKwon/cleanup-sql-conf.
HyukjinKwon authored and BryanCutler committed Mar 20, 2018
1 parent 5f4deff commit 5663218
Showing 1 changed file with 50 additions and 80 deletions.
130 changes: 50 additions & 80 deletions python/pyspark/sql/tests.py
@@ -2461,17 +2461,13 @@ def test_join_without_on(self):
df1 = self.spark.range(1).toDF("a")
df2 = self.spark.range(1).toDF("b")

try:
self.spark.conf.set("spark.sql.crossJoin.enabled", "false")
with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())

self.spark.conf.set("spark.sql.crossJoin.enabled", "true")
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
actual = df1.join(df2, how="inner").collect()
expected = [Row(a=0, b=0)]
self.assertEqual(actual, expected)
finally:
# We should unset this. Otherwise, other tests are affected.
self.spark.conf.unset("spark.sql.crossJoin.enabled")

# Regression test for invalid join methods when on is None, Spark-14761
def test_invalid_join_method(self):
@@ -2943,21 +2939,18 @@ def test_create_dateframe_from_pandas_with_dst(self):
self.assertPandasEqual(pdf, df.toPandas())

orig_env_tz = os.environ.get('TZ', None)
orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone')
try:
tz = 'America/Los_Angeles'
os.environ['TZ'] = tz
time.tzset()
self.spark.conf.set('spark.sql.session.timeZone', tz)

df = self.spark.createDataFrame(pdf)
self.assertPandasEqual(pdf, df.toPandas())
with self.sql_conf({'spark.sql.session.timeZone': tz}):
df = self.spark.createDataFrame(pdf)
self.assertPandasEqual(pdf, df.toPandas())
finally:
del os.environ['TZ']
if orig_env_tz is not None:
os.environ['TZ'] = orig_env_tz
time.tzset()
self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz)


class HiveSparkSubmitTests(SparkSubmitTests):
@@ -3562,12 +3555,11 @@ def test_null_conversion(self):
self.assertTrue(all([c == 1 for c in null_counts]))

def _toPandas_arrow_toggle(self, df):
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
try:
with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
pdf = df.toPandas()
finally:
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")

pdf_arrow = df.toPandas()

return pdf, pdf_arrow

def test_toPandas_arrow_toggle(self):
@@ -3579,16 +3571,17 @@ def test_toPandas_arrow_toggle(self):

def test_toPandas_respect_session_timezone(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
try:
timezone = "America/New_York"
self.spark.conf.set("spark.sql.session.timeZone", timezone)
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
try:
pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
self.assertPandasEqual(pdf_arrow_la, pdf_la)
finally:
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")

timezone = "America/New_York"
with self.sql_conf({
"spark.sql.execution.pandas.respectSessionTimeZone": False,
"spark.sql.session.timeZone": timezone}):
pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
self.assertPandasEqual(pdf_arrow_la, pdf_la)

with self.sql_conf({
"spark.sql.execution.pandas.respectSessionTimeZone": True,
"spark.sql.session.timeZone": timezone}):
pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
self.assertPandasEqual(pdf_arrow_ny, pdf_ny)

@@ -3601,8 +3594,6 @@ def test_toPandas_respect_session_timezone(self):
pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
pdf_la_corrected[field.name], timezone)
self.assertPandasEqual(pdf_ny, pdf_la_corrected)
finally:
self.spark.conf.set("spark.sql.session.timeZone", orig_tz)

def test_pandas_round_trip(self):
pdf = self.create_pandas_data_frame()
@@ -3618,12 +3609,11 @@ def test_filtered_frame(self):
self.assertTrue(pdf.empty)

def _createDataFrame_toggle(self, pdf, schema=None):
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
try:
with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
finally:
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")

df_arrow = self.spark.createDataFrame(pdf, schema=schema)

return df_no_arrow, df_arrow

def test_createDataFrame_toggle(self):
@@ -3634,18 +3624,18 @@ def test_createDataFrame_toggle(self):
def test_createDataFrame_respect_session_timezone(self):
from datetime import timedelta
pdf = self.create_pandas_data_frame()
orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
try:
timezone = "America/New_York"
self.spark.conf.set("spark.sql.session.timeZone", timezone)
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
try:
df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
result_la = df_no_arrow_la.collect()
result_arrow_la = df_arrow_la.collect()
self.assertEqual(result_la, result_arrow_la)
finally:
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
timezone = "America/New_York"
with self.sql_conf({
"spark.sql.execution.pandas.respectSessionTimeZone": False,
"spark.sql.session.timeZone": timezone}):
df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
result_la = df_no_arrow_la.collect()
result_arrow_la = df_arrow_la.collect()
self.assertEqual(result_la, result_arrow_la)

with self.sql_conf({
"spark.sql.execution.pandas.respectSessionTimeZone": True,
"spark.sql.session.timeZone": timezone}):
df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
result_ny = df_no_arrow_ny.collect()
result_arrow_ny = df_arrow_ny.collect()
@@ -3658,8 +3648,6 @@ def test_createDataFrame_respect_session_timezone(self):
for k, v in row.asDict().items()})
for row in result_la]
self.assertEqual(result_ny, result_la_corrected)
finally:
self.spark.conf.set("spark.sql.session.timeZone", orig_tz)

def test_createDataFrame_with_schema(self):
pdf = self.create_pandas_data_frame()
@@ -4336,9 +4324,7 @@ def gen_timestamps(id):
def test_vectorized_udf_check_config(self):
from pyspark.sql.functions import pandas_udf, col
import pandas as pd
orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None)
self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3)
try:
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
df = self.spark.range(10, numPartitions=1)

@pandas_udf(returnType=LongType())
@@ -4348,11 +4334,6 @@ def check_records_per_batch(x):
result = df.select(check_records_per_batch(col("id"))).collect()
for (r,) in result:
self.assertTrue(r <= 3)
finally:
if orig_value is None:
self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
else:
self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)

def test_vectorized_udf_timestamps_respect_session_timezone(self):
from pyspark.sql.functions import pandas_udf, col
@@ -4371,30 +4352,27 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
internal_value = pandas_udf(
lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())

orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
try:
timezone = "America/New_York"
self.spark.conf.set("spark.sql.session.timeZone", timezone)
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
try:
df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
.withColumn("internal_value", internal_value(col("timestamp")))
result_la = df_la.select(col("idx"), col("internal_value")).collect()
# Correct result_la by adjusting 3 hours difference between Los Angeles and New York
diff = 3 * 60 * 60 * 1000 * 1000 * 1000
result_la_corrected = \
df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
finally:
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
timezone = "America/New_York"
with self.sql_conf({
"spark.sql.execution.pandas.respectSessionTimeZone": False,
"spark.sql.session.timeZone": timezone}):
df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
.withColumn("internal_value", internal_value(col("timestamp")))
result_la = df_la.select(col("idx"), col("internal_value")).collect()
# Correct result_la by adjusting 3 hours difference between Los Angeles and New York
diff = 3 * 60 * 60 * 1000 * 1000 * 1000
result_la_corrected = \
df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()

with self.sql_conf({
"spark.sql.execution.pandas.respectSessionTimeZone": True,
"spark.sql.session.timeZone": timezone}):
df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
.withColumn("internal_value", internal_value(col("timestamp")))
result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()

self.assertNotEqual(result_ny, result_la)
self.assertEqual(result_ny, result_la_corrected)
finally:
self.spark.conf.set("spark.sql.session.timeZone", orig_tz)

def test_nondeterministic_vectorized_udf(self):
# Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
@@ -5170,22 +5148,14 @@ def test_complex_expressions(self):

def test_retain_group_columns(self):
from pyspark.sql.functions import sum, lit, col
orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None)
self.spark.conf.set("spark.sql.retainGroupColumns", False)
try:
with self.sql_conf({"spark.sql.retainGroupColumns": False}):
df = self.data
sum_udf = self.pandas_agg_sum_udf

result1 = df.groupby(df.id).agg(sum_udf(df.v))
expected1 = df.groupby(df.id).agg(sum(df.v))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())

finally:
if orig_value is None:
self.spark.conf.unset("spark.sql.retainGroupColumns")
else:
self.spark.conf.set("spark.sql.retainGroupColumns", orig_value)

def test_invalid_args(self):
from pyspark.sql.functions import mean

