From 47f74fd9b037b2b0770ac6e191d1717c22aed81b Mon Sep 17 00:00:00 2001
From: jhereth
Date: Mon, 14 Oct 2019 23:31:34 +0200
Subject: [PATCH] [SPARK-24915][PySpark] Fix Handling of Rows with Schema.

---
 python/pyspark/sql/tests/test_types.py | 46 ++++++++++++++++++++++++++
 python/pyspark/sql/types.py            |  3 ++
 2 files changed, 49 insertions(+)

diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 1cd84e0cd24e8..9ec48c10b8fe6 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -215,6 +215,52 @@ def test_create_dataframe_from_objects(self):
         self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")])
         self.assertEqual(df.first(), Row(key=1, value="1"))
 
+    def test_create_dataframe_from_rows_mixed_with_datetype(self):
+        data = [Row(name='Alice', join_date=datetime.date(2014, 5, 26)),
+                Row(name='Bob', join_date=datetime.date(2016, 7, 26))]
+        schema1 = StructType([
+            StructField("join_date", DateType(), False),
+            StructField("name", StringType(), False),
+        ])
+        schema2 = StructType([
+            StructField("name", StringType(), False),
+            StructField("join_date", DateType(), False),
+        ])
+        df = self.spark.createDataFrame(data, schema=schema1)
+        self.assertEqual(df.dtypes, [("join_date", "date"), ("name", "string")])
+        self.assertEqual(df.first().asDict(),
+                         Row(name='Alice', join_date=datetime.date(2014, 5, 26)).asDict())
+        df = self.spark.createDataFrame(data, schema=schema2)
+        self.assertEqual(df.dtypes, [("name", "string"), ("join_date", "date")])
+        self.assertEqual(df.first().asDict(),
+                         Row(name='Alice', join_date=datetime.date(2014, 5, 26)).asDict())
+
+    def test_create_dataframe_from_rows_with_nested_row(self):
+        schema = StructType([
+            StructField('field2',
+                        StructType([
+                            StructField('sub_field', StringType(), False)
+                        ]), False),
+            StructField('field1', StringType(), False),
+        ])
+        row = Row(field1="Hello", field2=Row(sub_field='world'))
+        data = [row]
+        df = self.spark.createDataFrame(data, schema=schema)
+        self.assertEqual(df.dtypes, [('field2', 'struct<sub_field:string>'),
+                                     ('field1', 'string')])
+        self.assertEqual(df.first().asDict(), row.asDict())
+
+    def test_create_dataframe_from_tuple_rows(self):
+        data = [Row('Alice', datetime.date(2014, 5, 26)),
+                Row('Bob', datetime.date(2016, 7, 26))]
+        schema = StructType([
+            StructField("name", StringType(), False),
+            StructField("join_date", DateType(), False),
+        ])
+        df = self.spark.createDataFrame(data, schema=schema)
+        self.assertEqual(df.dtypes, [("name", "string"), ("join_date", "date")])
+        self.assertEqual(df.first(), Row('Alice', datetime.date(2014, 5, 26)))
+
     def test_apply_schema(self):
         from datetime import date, datetime
         rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 81fdd41435694..e9b517c18555f 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -599,6 +599,9 @@ def toInternal(self, obj):
         if isinstance(obj, dict):
             return tuple(f.toInternal(obj.get(n)) if c else obj.get(n)
                          for n, f, c in zip(self.names, self.fields, self._needConversion))
+        elif isinstance(obj, Row) and getattr(obj, "__from_dict__", False):
+            return tuple(f.toInternal(obj[n]) if c else obj[n]
+                         for n, f, c in zip(self.names, self.fields, self._needConversion))
         elif isinstance(obj, (tuple, list)):
             return tuple(f.toInternal(v) if c else v
                          for f, v, c in zip(self.fields, obj, self._needConversion))