Skip to content

Commit

Permalink
[SPARK-24915][PySpark] Fix Handling of Rows with Schema.
Browse files Browse the repository at this point in the history
  • Loading branch information
jhereth committed Nov 7, 2019
1 parent 9b61f90 commit 47f74fd
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
46 changes: 46 additions & 0 deletions python/pyspark/sql/tests/test_types.py
Expand Up @@ -215,6 +215,52 @@ def test_create_dataframe_from_objects(self):
self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
self.assertEqual(df.first(), Row(key=1, value="1"))

def test_create_dataframe_from_rows_mixed_with_datetype(self):
    """Keyword Rows mixing DateType and StringType must be matched to the
    schema by field name, regardless of the schema's column order.

    Runs the same data through two schemas that differ only in field order
    and checks both the resulting dtypes and the first materialized row.
    """
    rows = [Row(name='Alice', join_date=datetime.date(2014, 5, 26)),
            Row(name='Bob', join_date=datetime.date(2016, 7, 26))]
    # The first row, as a dict, is the expected content either way: field
    # order in the schema must not reshuffle values across columns.
    expected_first = Row(name='Alice', join_date=datetime.date(2014, 5, 26)).asDict()
    cases = [
        (StructType([StructField("join_date", DateType(), False),
                     StructField("name", StringType(), False)]),
         [("join_date", "date"), ("name", "string")]),
        (StructType([StructField("name", StringType(), False),
                     StructField("join_date", DateType(), False)]),
         [("name", "string"), ("join_date", "date")]),
    ]
    for schema, expected_dtypes in cases:
        df = self.spark.createDataFrame(rows, schema=schema)
        self.assertEqual(df.dtypes, expected_dtypes)
        self.assertEqual(df.first().asDict(), expected_first)

def test_create_dataframe_from_rows_with_nested_row(self):
    """A Row containing a nested Row must be matched to a nested StructType
    schema by field name, even when the schema lists the struct field first
    while the Row was constructed with it second.
    """
    nested_schema = StructType([
        StructField('field2',
                    StructType([StructField('sub_field', StringType(), False)]),
                    False),
        StructField('field1', StringType(), False),
    ])
    source_row = Row(field1="Hello", field2=Row(sub_field='world'))
    df = self.spark.createDataFrame([source_row], schema=nested_schema)
    # dtypes follow the schema's order, not the Row's construction order.
    self.assertEqual(df.dtypes,
                     [('field2', 'struct<sub_field:string>'),
                      ('field1', 'string')])
    self.assertEqual(df.first().asDict(), source_row.asDict())

def test_create_dataframe_from_tuple_rows(self):
    """Positional (tuple-style) Rows carry no field names, so they must be
    mapped onto the schema's fields strictly by position.
    """
    positional_rows = [Row('Alice', datetime.date(2014, 5, 26)),
                       Row('Bob', datetime.date(2016, 7, 26))]
    tuple_schema = StructType([
        StructField("name", StringType(), False),
        StructField("join_date", DateType(), False),
    ])
    df = self.spark.createDataFrame(positional_rows, schema=tuple_schema)
    self.assertEqual(df.dtypes, [("name", "string"), ("join_date", "date")])
    # First row should round-trip positionally: name first, date second.
    self.assertEqual(df.first(), Row('Alice', datetime.date(2014, 5, 26)))

def test_apply_schema(self):
from datetime import date, datetime
rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
Expand Down
3 changes: 3 additions & 0 deletions python/pyspark/sql/types.py
Expand Up @@ -599,6 +599,9 @@ def toInternal(self, obj):
if isinstance(obj, dict):
return tuple(f.toInternal(obj.get(n)) if c else obj.get(n)
for n, f, c in zip(self.names, self.fields, self._needConversion))
elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
return tuple(f.toInternal(obj[n]) if c else obj[n]
for n, f, c in zip(self.names, self.fields, self._needConversion))
elif isinstance(obj, (tuple, list)):
return tuple(f.toInternal(v) if c else v
for f, v, c in zip(self.fields, obj, self._needConversion))
Expand Down

0 comments on commit 47f74fd

Please sign in to comment.