From efe113fdb8f2ec7ba57766b0b87a6c8df9728291 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 13 Jul 2017 11:06:44 +0900
Subject: [PATCH 1/3] Add StructType.fieldNames in PySpark

---
 python/pyspark/sql/tests.py | 10 +++++++---
 python/pyspark/sql/types.py | 12 +++++++++++-
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 29e48a6ccf763..a6b23bbd39f32 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1216,21 +1216,25 @@ def test_struct_type(self):
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True),
                               StructField("f2", StringType(), True, None)])
+        self.assertEqual(struct1.fieldNames(), tuple(struct2.names))
         self.assertEqual(struct1, struct2)
 
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True)])
+        self.assertNotEqual(struct1.fieldNames(), tuple(struct2.names))
         self.assertNotEqual(struct1, struct2)
 
         struct1 = (StructType().add(StructField("f1", StringType(), True))
                    .add(StructField("f2", StringType(), True, None)))
         struct2 = StructType([StructField("f1", StringType(), True),
                               StructField("f2", StringType(), True, None)])
+        self.assertEqual(struct1.fieldNames(), tuple(struct2.names))
         self.assertEqual(struct1, struct2)
 
         struct1 = (StructType().add(StructField("f1", StringType(), True))
                    .add(StructField("f2", StringType(), True, None)))
         struct2 = StructType([StructField("f1", StringType(), True)])
+        self.assertNotEqual(struct1.fieldNames(), tuple(struct2.names))
         self.assertNotEqual(struct1, struct2)
 
         # Catch exception raised during improper construction
@@ -1249,11 +1253,11 @@ def test_struct_type(self):
         self.assertIs(struct1[0], struct1.fields[0])
         self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
         with self.assertRaises(KeyError):
-            not_a_field = struct1["f9"]
+            _ = struct1["f9"]
         with self.assertRaises(IndexError):
-            not_a_field = struct1[9]
+            _ = struct1[9]
         with self.assertRaises(TypeError):
-            not_a_field = struct1[9.9]
+            _ = struct1[9.9]
 
     def test_parse_datatype_string(self):
         from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 22fa273fc1aac..20c21d12b1b58 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -445,7 +445,7 @@ class StructType(DataType):
 
     This is the data type representing a :class:`Row`.
 
-    Iterating a :class:`StructType` will iterate its :class:`StructField`s.
+    Iterating a :class:`StructType` will iterate its :class:`StructField`\\s.
     A contained :class:`StructField` can be accessed by name or position.
 
     >>> struct1 = StructType([StructField("f1", StringType(), True)])
@@ -562,6 +562,16 @@ def jsonValue(self):
     def fromJson(cls, json):
         return StructType([StructField.fromJson(f) for f in json["fields"]])
 
+    def fieldNames(self):
+        """
+        Returns all field names in a tuple.
+
+        >>> struct = StructType([StructField("f1", StringType(), True)])
+        >>> struct.fieldNames()
+        ('f1',)
+        """
+        return tuple(self.names)
+
     def needConversion(self):
         # We need convert Row()/namedtuple into tuple()
         return True

From eaa910dfdb31448724384352674440232fb584b6 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Sat, 15 Jul 2017 13:15:58 +0900
Subject: [PATCH 2/3] Make the return type to a list instead of a tuple

---
 python/pyspark/sql/tests.py | 20 ++++++++------------
 python/pyspark/sql/types.py |  6 +++---
 2 files changed, 11 insertions(+), 15 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a6b23bbd39f32..53706ba5625db 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1216,30 +1216,29 @@ def test_struct_type(self):
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True),
                               StructField("f2", StringType(), True, None)])
-        self.assertEqual(struct1.fieldNames(), tuple(struct2.names))
+        self.assertEqual(struct1.fieldNames(), struct2.names)
         self.assertEqual(struct1, struct2)
 
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True)])
-        self.assertNotEqual(struct1.fieldNames(), tuple(struct2.names))
+        self.assertNotEqual(struct1.fieldNames(), struct2.names)
         self.assertNotEqual(struct1, struct2)
 
         struct1 = (StructType().add(StructField("f1", StringType(), True))
                    .add(StructField("f2", StringType(), True, None)))
         struct2 = StructType([StructField("f1", StringType(), True),
                               StructField("f2", StringType(), True, None)])
-        self.assertEqual(struct1.fieldNames(), tuple(struct2.names))
+        self.assertEqual(struct1.fieldNames(), struct2.names)
         self.assertEqual(struct1, struct2)
 
         struct1 = (StructType().add(StructField("f1", StringType(), True))
                    .add(StructField("f2", StringType(), True, None)))
         struct2 = StructType([StructField("f1", StringType(), True)])
-        self.assertNotEqual(struct1.fieldNames(), tuple(struct2.names))
+        self.assertNotEqual(struct1.fieldNames(), struct2.names)
         self.assertNotEqual(struct1, struct2)
 
         # Catch exception raised during improper construction
-        with self.assertRaises(ValueError):
-            struct1 = StructType().add("name")
+        self.assertRaises(ValueError, lambda: StructType().add("name"))
 
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         for field in struct1:
@@ -1252,12 +1251,9 @@ def test_struct_type(self):
         self.assertIs(struct1["f1"], struct1.fields[0])
         self.assertIs(struct1[0], struct1.fields[0])
         self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
-        with self.assertRaises(KeyError):
-            _ = struct1["f9"]
-        with self.assertRaises(IndexError):
-            _ = struct1[9]
-        with self.assertRaises(TypeError):
-            _ = struct1[9.9]
+        self.assertRaises(KeyError, lambda: struct1["f9"])
+        self.assertRaises(IndexError, lambda: struct1[9])
+        self.assertRaises(TypeError, lambda: struct1[9.9])
 
     def test_parse_datatype_string(self):
         from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 20c21d12b1b58..14e0c62161daa 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -564,13 +564,13 @@ def fromJson(cls, json):
 
     def fieldNames(self):
         """
-        Returns all field names in a tuple.
+        Returns all field names in a list.
 
         >>> struct = StructType([StructField("f1", StringType(), True)])
         >>> struct.fieldNames()
-        ('f1',)
+        ['f1']
         """
-        return tuple(self.names)
+        return list(self.names)
 
     def needConversion(self):
         # We need convert Row()/namedtuple into tuple()

From 86493bee60502e326d53ad61dd260b78db8050d5 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Fri, 21 Jul 2017 10:00:04 +0900
Subject: [PATCH 3/3] Add a note for deprecation for names attribute in
 StructType

---
 python/pyspark/sql/types.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 14e0c62161daa..a81aaa30903dd 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -448,6 +448,9 @@ class StructType(DataType):
     Iterating a :class:`StructType` will iterate its :class:`StructField`\\s.
     A contained :class:`StructField` can be accessed by name or position.
 
+    .. note:: `names` attribute is deprecated in 2.3. Use `fieldNames` method instead
+        to get a list of field names.
+
     >>> struct1 = StructType([StructField("f1", StringType(), True)])
     >>> struct1["f1"]
     StructField(f1,StringType,true)