Skip to content

Commit

Permalink
[SPARK-23290][SQL][PYTHON] Use datetime.date for date type when conve…
Browse files Browse the repository at this point in the history
…rting Spark DataFrame to Pandas DataFrame.

## What changes were proposed in this pull request?

In #18664, there was a change in how `DateType` is being returned to users ([line 1968 in dataframe.py](https://github.com/apache/spark/pull/18664/files#diff-6fc344560230bf0ef711bb9b5573f1faR1968)). This can cause client code which works in Spark 2.2 to fail.
See [SPARK-23290](https://issues.apache.org/jira/browse/SPARK-23290?focusedCommentId=16350917&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-16350917) for an example.

This pr modifies to use `datetime.date` for date type as Spark 2.2 does.

## How was this patch tested?

Tests modified to fit the new behavior and existing tests.

Author: Takuya UESHIN <ueshin@databricks.com>

Closes #20506 from ueshin/issues/SPARK-23290.
  • Loading branch information
ueshin committed Feb 6, 2018
1 parent 4aa9aaf commit b489f4a
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 22 deletions.
9 changes: 6 additions & 3 deletions python/pyspark/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,12 +267,15 @@ def load_stream(self, stream):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
from pyspark.sql.types import _check_dataframe_localize_timestamps
from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \
_check_dataframe_localize_timestamps
import pyarrow as pa
reader = pa.open_stream(stream)
schema = from_arrow_schema(reader.schema)
for batch in reader:
# NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1
pdf = _check_dataframe_localize_timestamps(batch.to_pandas(), self._timezone)
pdf = batch.to_pandas()
pdf = _check_dataframe_convert_date(pdf, schema)
pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
yield [c for _, c in pdf.iteritems()]

def __repr__(self):
Expand Down
7 changes: 3 additions & 4 deletions python/pyspark/sql/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1923,14 +1923,16 @@ def toPandas(self):

if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true":
try:
from pyspark.sql.types import _check_dataframe_localize_timestamps
from pyspark.sql.types import _check_dataframe_convert_date, \
_check_dataframe_localize_timestamps
from pyspark.sql.utils import require_minimum_pyarrow_version
import pyarrow
require_minimum_pyarrow_version()
tables = self._collectAsArrow()
if tables:
table = pyarrow.concat_tables(tables)
pdf = table.to_pandas()
pdf = _check_dataframe_convert_date(pdf, self.schema)
return _check_dataframe_localize_timestamps(pdf, timezone)
else:
return pd.DataFrame.from_records([], columns=self.columns)
Expand Down Expand Up @@ -2009,7 +2011,6 @@ def _to_corrected_pandas_type(dt):
"""
When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong.
This method gets the corrected data type for Pandas if that type may be inferred uncorrectly.
NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns]
"""
import numpy as np
if type(dt) == ByteType:
Expand All @@ -2020,8 +2021,6 @@ def _to_corrected_pandas_type(dt):
return np.int32
elif type(dt) == FloatType:
return np.float32
elif type(dt) == DateType:
return 'datetime64[ns]'
else:
return None

Expand Down
57 changes: 42 additions & 15 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2810,7 +2810,7 @@ def test_to_pandas(self):
self.assertEquals(types[1], np.object)
self.assertEquals(types[2], np.bool)
self.assertEquals(types[3], np.float32)
self.assertEquals(types[4], 'datetime64[ns]')
self.assertEquals(types[4], np.object) # datetime.date
self.assertEquals(types[5], 'datetime64[ns]')

@unittest.skipIf(not _have_old_pandas, "Old Pandas not installed")
Expand Down Expand Up @@ -3356,7 +3356,7 @@ class ArrowTests(ReusedSQLTestCase):

@classmethod
def setUpClass(cls):
from datetime import datetime
from datetime import date, datetime
from decimal import Decimal
ReusedSQLTestCase.setUpClass()

Expand All @@ -3378,11 +3378,11 @@ def setUpClass(cls):
StructField("7_date_t", DateType(), True),
StructField("8_timestamp_t", TimestampType(), True)])
cls.data = [(u"a", 1, 10, 0.2, 2.0, Decimal("2.0"),
datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
(u"b", 2, 20, 0.4, 4.0, Decimal("4.0"),
datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
(u"c", 3, 30, 0.8, 6.0, Decimal("6.0"),
datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]
date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))]

@classmethod
def tearDownClass(cls):
Expand Down Expand Up @@ -3435,7 +3435,9 @@ def _toPandas_arrow_toggle(self, df):
def test_toPandas_arrow_toggle(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
self.assertFramesEqual(pdf_arrow, pdf)
expected = self.create_pandas_data_frame()
self.assertFramesEqual(expected, pdf)
self.assertFramesEqual(expected, pdf_arrow)

def test_toPandas_respect_session_timezone(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
Expand Down Expand Up @@ -4036,18 +4038,42 @@ def test_vectorized_udf_unsupported_types(self):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.select(f(col('map'))).collect()

def test_vectorized_udf_null_date(self):
def test_vectorized_udf_dates(self):
from pyspark.sql.functions import pandas_udf, col
from datetime import date
schema = StructType().add("date", DateType())
data = [(date(1969, 1, 1),),
(date(2012, 2, 2),),
(None,),
(date(2100, 4, 4),)]
schema = StructType().add("idx", LongType()).add("date", DateType())
data = [(0, date(1969, 1, 1),),
(1, date(2012, 2, 2),),
(2, None,),
(3, date(2100, 4, 4),)]
df = self.spark.createDataFrame(data, schema=schema)
date_f = pandas_udf(lambda t: t, returnType=DateType())
res = df.select(date_f(col("date")))
self.assertEquals(df.collect(), res.collect())

date_copy = pandas_udf(lambda t: t, returnType=DateType())
df = df.withColumn("date_copy", date_copy(col("date")))

@pandas_udf(returnType=StringType())
def check_data(idx, date, date_copy):
import pandas as pd
msgs = []
is_equal = date.isnull()
for i in range(len(idx)):
if (is_equal[i] and data[idx[i]][1] is None) or \
date[i] == data[idx[i]][1]:
msgs.append(None)
else:
msgs.append(
"date values are not equal (date='%s': data[%d][1]='%s')"
% (date[i], idx[i], data[idx[i]][1]))
return pd.Series(msgs)

result = df.withColumn("check_data",
check_data(col("idx"), col("date"), col("date_copy"))).collect()

self.assertEquals(len(data), len(result))
for i in range(len(result)):
self.assertEquals(data[i][1], result[i][1]) # "date" col
self.assertEquals(data[i][1], result[i][2]) # "date_copy" col
self.assertIsNone(result[i][3]) # "check_data" col

def test_vectorized_udf_timestamps(self):
from pyspark.sql.functions import pandas_udf, col
Expand Down Expand Up @@ -4088,6 +4114,7 @@ def check_data(idx, timestamp, timestamp_copy):
self.assertEquals(len(data), len(result))
for i in range(len(result)):
self.assertEquals(data[i][1], result[i][1]) # "timestamp" col
self.assertEquals(data[i][1], result[i][2]) # "timestamp_copy" col
self.assertIsNone(result[i][3]) # "check_data" col

def test_vectorized_udf_return_timestamp_tz(self):
Expand Down
15 changes: 15 additions & 0 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1694,6 +1694,21 @@ def from_arrow_schema(arrow_schema):
for field in arrow_schema])


def _check_dataframe_convert_date(pdf, schema):
""" Correct date type value to use datetime.date.
Pandas DataFrame created from PyArrow uses datetime64[ns] for date type values, but we should
use datetime.date to match the behavior with when Arrow optimization is disabled.
:param pdf: pandas.DataFrame
:param schema: a Spark schema of the pandas.DataFrame
"""
for field in schema:
if type(field.dataType) == DateType:
pdf[field.name] = pdf[field.name].dt.date
return pdf


def _check_dataframe_localize_timestamps(pdf, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
Expand Down

0 comments on commit b489f4a

Please sign in to comment.