From 67d1bf4b6cb7e14c4aa69dee450fb351d164946e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 26 Aug 2015 15:25:16 -0700 Subject: [PATCH] fix create DataFrame from Python class --- python/pyspark/sql/tests.py | 12 ++++++++++++ python/pyspark/sql/types.py | 6 ++++++ 2 files changed, 18 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index aacfb34c77618..cd32e26c64f22 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -145,6 +145,12 @@ class PythonOnlyPoint(ExamplePoint): __UDT__ = PythonOnlyUDT() +class MyObject(object): + def __init__(self, key, value): + self.key = key + self.value = value + + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 def test_data_type_eq(self): @@ -383,6 +389,12 @@ def test_infer_nested_schema(self): df = self.sqlCtx.inferSchema(rdd) self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + def test_create_dataframe_from_objects(self): + data = [MyObject(1, "1"), MyObject(2, "2")] + df = self.sqlCtx.createDataFrame(data) + self.assertEqual(df.dtypes, [("key", "bigint"), ("value", "string")]) + self.assertEqual(df.first(), Row(key=1, value="1")) + def test_select_null_literal(self): df = self.sqlCtx.sql("select null as col") self.assertEquals(Row(col=None), df.first()) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ed4e5b594bd61..94e581a78364c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -537,6 +537,9 @@ def toInternal(self, obj): return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields)) elif isinstance(obj, (tuple, list)): return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + return tuple(f.toInternal(d.get(n)) for n, f in zip(self.names, self.fields)) else: raise ValueError("Unexpected tuple %r with StructType" % obj) else: @@ -544,6 +547,9 @@ def toInternal(self, obj): return tuple(obj.get(n) for n in self.names) elif isinstance(obj, (list, tuple)): return tuple(obj) + elif hasattr(obj, "__dict__"): + d = obj.__dict__ + return tuple(d.get(n) for n in self.names) else: raise ValueError("Unexpected tuple %r with StructType" % obj)