[SPARK-23691][PYTHON][BRANCH-2.3] Use sql_conf util in PySpark tests where possible

## What changes were proposed in this pull request?

This PR backports #20830 to reduce the diff against master and to restore default configuration values in PySpark tests.

d6632d1 added a useful util. This backport extracts and brings over that util:

```python
@contextmanager
def sql_conf(self, pairs):
    ...
```

so that a configuration can be set within a block and restored afterwards:

```python
with self.sql_conf({"spark.blah.blah.blah": "blah"}):
    # test codes
```
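
For reference, this context manager wraps the save/set/try/finally/restore boilerplate that each test otherwise has to repeat. Below is a rough, illustrative sketch of that pattern for a single key (the helper name `set_conf_temporarily` and the `spark`/`test_body` parameters are placeholders, not code from this patch):

```python
# Illustrative sketch (not from this patch): the boilerplate that sql_conf
# encapsulates -- remember the old value, set the new one, and restore it
# (or unset it) even if the test body raises.
def set_conf_temporarily(spark, key, value, test_body):
    old_value = spark.conf.get(key, None)   # remember the current value, if any
    spark.conf.set(key, value)
    try:
        test_body()                          # run the test code with the temporary setting
    finally:
        if old_value is None:
            spark.conf.unset(key)            # the key was unset before, so unset it again
        else:
            spark.conf.set(key, old_value)   # otherwise restore the previous value
```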

This PR proposes to use this util where possible in PySpark tests.

Note that there already appear to be a few places where tests set configurations without restoring the original values in the unittest classes.

## How was this patch tested?

As in #20830, this was manually tested via:

```
./run-tests --modules=pyspark-sql --python-executables=python2
./run-tests --modules=pyspark-sql --python-executables=python3
```

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #20863 from HyukjinKwon/backport-20830.
HyukjinKwon committed Mar 20, 2018
1 parent 2f82c03 commit c854b6ca7ba4dc33138c12ba4606ff8fbe82aef2
Showing with 72 additions and 71 deletions.
  1. +72 −71 python/pyspark/sql/tests.py
```diff
@@ -33,6 +33,7 @@
 import array
 import ctypes
 import py4j
+from contextlib import contextmanager
 
 try:
     import xmlrunner
@@ -201,6 +202,28 @@ def assertPandasEqual(self, expected, result):
                "\n\nResult:\n%s\n%s" % (result, result.dtypes))
         self.assertTrue(expected.equals(result), msg=msg)
 
+    @contextmanager
+    def sql_conf(self, pairs):
+        """
+        A convenient context manager to test some configuration specific logic. This sets
+        `value` to the configuration `key` and then restores it back when it exits.
+        """
+        assert isinstance(pairs, dict), "pairs should be a dictionary."
+
+        keys = pairs.keys()
+        new_values = pairs.values()
+        old_values = [self.spark.conf.get(key, None) for key in keys]
+        for key, new_value in zip(keys, new_values):
+            self.spark.conf.set(key, new_value)
+        try:
+            yield
+        finally:
+            for key, old_value in zip(keys, old_values):
+                if old_value is None:
+                    self.spark.conf.unset(key)
+                else:
+                    self.spark.conf.set(key, old_value)
+
 
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
@@ -2409,17 +2432,13 @@ def test_join_without_on(self):
         df1 = self.spark.range(1).toDF("a")
         df2 = self.spark.range(1).toDF("b")
 
-        try:
-            self.spark.conf.set("spark.sql.crossJoin.enabled", "false")
+        with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
             self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())
 
-            self.spark.conf.set("spark.sql.crossJoin.enabled", "true")
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
             actual = df1.join(df2, how="inner").collect()
             expected = [Row(a=0, b=0)]
             self.assertEqual(actual, expected)
-        finally:
-            # We should unset this. Otherwise, other tests are affected.
-            self.spark.conf.unset("spark.sql.crossJoin.enabled")
 
     # Regression test for invalid join methods when on is None, Spark-14761
     def test_invalid_join_method(self):
@@ -2891,21 +2910,18 @@ def test_create_dateframe_from_pandas_with_dst(self):
         self.assertPandasEqual(pdf, df.toPandas())
 
         orig_env_tz = os.environ.get('TZ', None)
-        orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone')
         try:
             tz = 'America/Los_Angeles'
             os.environ['TZ'] = tz
             time.tzset()
-            self.spark.conf.set('spark.sql.session.timeZone', tz)
-
-            df = self.spark.createDataFrame(pdf)
-            self.assertPandasEqual(pdf, df.toPandas())
+            with self.sql_conf({'spark.sql.session.timeZone': tz}):
+                df = self.spark.createDataFrame(pdf)
+                self.assertPandasEqual(pdf, df.toPandas())
         finally:
             del os.environ['TZ']
             if orig_env_tz is not None:
                 os.environ['TZ'] = orig_env_tz
             time.tzset()
-            self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz)
 
 
 class HiveSparkSubmitTests(SparkSubmitTests):
@@ -3472,12 +3488,11 @@ def test_null_conversion(self):
         self.assertTrue(all([c == 1 for c in null_counts]))
 
     def _toPandas_arrow_toggle(self, df):
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             pdf = df.toPandas()
-        finally:
-            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
 
         pdf_arrow = df.toPandas()
 
         return pdf, pdf_arrow
 
     def test_toPandas_arrow_toggle(self):
@@ -3489,16 +3504,17 @@ def test_toPandas_arrow_toggle(self):
 
     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
-                self.assertPandasEqual(pdf_arrow_la, pdf_la)
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
 
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
+            self.assertPandasEqual(pdf_arrow_la, pdf_la)
+
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
             self.assertPandasEqual(pdf_arrow_ny, pdf_ny)
 
@@ -3511,8 +3527,6 @@ def test_toPandas_respect_session_timezone(self):
                 pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
                     pdf_la_corrected[field.name], timezone)
             self.assertPandasEqual(pdf_ny, pdf_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
     def test_pandas_round_trip(self):
         pdf = self.create_pandas_data_frame()
@@ -3528,12 +3542,11 @@ def test_filtered_frame(self):
         self.assertTrue(pdf.empty)
 
     def _createDataFrame_toggle(self, pdf, schema=None):
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
-        finally:
-            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
 
         df_arrow = self.spark.createDataFrame(pdf, schema=schema)
 
         return df_no_arrow, df_arrow
 
     def test_createDataFrame_toggle(self):
@@ -3544,18 +3557,18 @@ def test_createDataFrame_toggle(self):
     def test_createDataFrame_respect_session_timezone(self):
         from datetime import timedelta
         pdf = self.create_pandas_data_frame()
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
-                result_la = df_no_arrow_la.collect()
-                result_arrow_la = df_arrow_la.collect()
-                self.assertEqual(result_la, result_arrow_la)
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
+            result_la = df_no_arrow_la.collect()
+            result_arrow_la = df_arrow_la.collect()
+            self.assertEqual(result_la, result_arrow_la)
 
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
             result_ny = df_no_arrow_ny.collect()
             result_arrow_ny = df_arrow_ny.collect()
@@ -3568,8 +3581,6 @@ def test_createDataFrame_respect_session_timezone(self):
                                           for k, v in row.asDict().items()})
                                   for row in result_la]
             self.assertEqual(result_ny, result_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
     def test_createDataFrame_with_schema(self):
         pdf = self.create_pandas_data_frame()
@@ -4222,9 +4233,7 @@ def gen_timestamps(id):
     def test_vectorized_udf_check_config(self):
         from pyspark.sql.functions import pandas_udf, col
         import pandas as pd
-        orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None)
-        self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3)
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
             df = self.spark.range(10, numPartitions=1)
 
             @pandas_udf(returnType=LongType())
@@ -4234,11 +4243,6 @@ def check_records_per_batch(x):
             result = df.select(check_records_per_batch(col("id"))).collect()
             for (r,) in result:
                 self.assertTrue(r <= 3)
-        finally:
-            if orig_value is None:
-                self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
-            else:
-                self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)
 
     def test_vectorized_udf_timestamps_respect_session_timezone(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -4257,30 +4261,27 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
         internal_value = pandas_udf(
             lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())
 
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
-                    .withColumn("internal_value", internal_value(col("timestamp")))
-                result_la = df_la.select(col("idx"), col("internal_value")).collect()
-                # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
-                diff = 3 * 60 * 60 * 1000 * 1000 * 1000
-                result_la_corrected = \
-                    df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
+                .withColumn("internal_value", internal_value(col("timestamp")))
+            result_la = df_la.select(col("idx"), col("internal_value")).collect()
+            # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
+            diff = 3 * 60 * 60 * 1000 * 1000 * 1000
+            result_la_corrected = \
+                df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
 
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
                 .withColumn("internal_value", internal_value(col("timestamp")))
             result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()
 
             self.assertNotEqual(result_ny, result_la)
             self.assertEqual(result_ny, result_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
     def test_nondeterministic_vectorized_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
```