From 90207c4a29bae033ec4c144faba6a5cd7ff4071e Mon Sep 17 00:00:00 2001 From: "Sheamus K. Parkes" Date: Thu, 7 Apr 2016 21:41:43 -0400 Subject: [PATCH 1/9] Expand pyspark.sql.types.StructType tests. Do this to cover the anticipated new APIs. --- python/pyspark/sql/tests.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e4f79c911c0d9..e04b02935a0db 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -802,11 +802,26 @@ def test_struct_type(self): self.assertNotEqual(struct1, struct2) # Catch exception raised during improper construction - try: + with self.assertRaises(ValueError): struct1 = StructType().add("name") - self.assertEqual(1, 0) - except ValueError: - self.assertEqual(1, 1) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + for field in struct1: + self.assertIsInstance(field, StructField) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + self.assertEqual(len(struct1), 2) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + self.assertIs(struct1["f1"], struct1.fields[0]) + self.assertIs(struct1[0], struct1.fields[0]) + self.assertEqual(struct1[0:1], struct1.fields[0:1]) + with self.assertRaises(KeyError): + not_a_field = struct1["f9"] + with self.assertRaises(IndexError): + not_a_field = struct1[9] + with self.assertRaises(TypeError): + not_a_field = struct1[9.9] def test_metadata_null(self): from pyspark.sql.types import StructType, StringType, StructField From 139243aae61ad97f11c32a512337ca9de857719e Mon Sep 17 00:00:00 2001 From: "Sheamus K. Parkes" Date: Thu, 7 Apr 2016 21:43:48 -0400 Subject: [PATCH 2/9] Expand pyspark.sql.types.StructType accessors. Do this to support more of the Pythonic "magic" methods. 
--- python/pyspark/sql/types.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 734c1533a24bc..3cb4eb4f4d154 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -511,6 +511,30 @@ def add(self, field, data_type=None, nullable=True, metadata=None): self._needSerializeAnyField = any(f.needConversion() for f in self.fields) return self + def __iter__(self): + """Iterate the fields""" + return iter(self.fields) + + def __len__(self): + """Return the number of fields.""" + return len(self.fields) + + def __getitem__(self, key): + """Access fields by name or slice.""" + if isinstance(key, str): + _dict_fields = {field.name: field for field in self} + try: + return _dict_fields[key] + except KeyError: + raise KeyError('No StructField named {}'.format(key)) + elif isinstance(key, (int, slice)): + try: + return self.fields[key] + except IndexError: + raise IndexError('StructType index out of range') + else: + raise TypeError('StructType keys should be strings, integers or slices') + def simpleString(self): return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) From 03931159ca433066089087fee7a50ef6ebdeb220 Mon Sep 17 00:00:00 2001 From: "Sheamus K. Parkes" Date: Thu, 7 Apr 2016 21:47:55 -0400 Subject: [PATCH 3/9] Dogfood new __iter__ in pyspark.sql.types.StructType. Do this because the more direct syntax is now available. 
--- python/pyspark/sql/types.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 3cb4eb4f4d154..a6c7a54a27ce7 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -463,7 +463,7 @@ def __init__(self, fields=None): self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" - self._needSerializeAnyField = any(f.needConversion() for f in self.fields) + self._needSerializeAnyField = any(f.needConversion() for f in self) def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -508,7 +508,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None): data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) self.names.append(field) - self._needSerializeAnyField = any(f.needConversion() for f in self.fields) + self._needSerializeAnyField = any(f.needConversion() for f in self) return self def __iter__(self): @@ -536,15 +536,15 @@ def __getitem__(self, key): raise TypeError('StructType keys should be strings, integers or slices') def simpleString(self): - return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) + return 'struct<%s>' % (','.join(f.simpleString() for f in self)) def __repr__(self): return ("StructType(List(%s))" % - ",".join(str(field) for field in self.fields)) + ",".join(str(field) for field in self)) def jsonValue(self): return {"type": self.typeName(), - "fields": [f.jsonValue() for f in self.fields]} + "fields": [f.jsonValue() for f in self]} @classmethod def fromJson(cls, json): From 831f03a3098d0e165703db6be0f4d76ef4893667 Mon Sep 17 00:00:00 2001 From: "Sheamus K. Parkes" Date: Thu, 7 Apr 2016 21:53:29 -0400 Subject: [PATCH 4/9] Make pyspark.sql.types.StructType attributes lazier. Do this so there is less chance for someone to get the stateful attributes out of sync. 
--- python/pyspark/sql/tests.py | 3 +++ python/pyspark/sql/types.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e04b02935a0db..bdb49d0287a88 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -823,6 +823,9 @@ def test_struct_type(self): with self.assertRaises(TypeError): not_a_field = struct1[9.9] + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + self.assertEqual(struct1.names, ["f1", "f2"]) + def test_metadata_null(self): from pyspark.sql.types import StructType, StringType, StructField schema = StructType([StructField("f1", StringType(), True, None), diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index a6c7a54a27ce7..02eef46bc471a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -457,13 +457,10 @@ def __init__(self, fields=None): """ if not fields: self.fields = [] - self.names = [] else: self.fields = fields - self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" - self._needSerializeAnyField = any(f.needConversion() for f in self) def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -497,7 +494,6 @@ def add(self, field, data_type=None, nullable=True, metadata=None): """ if isinstance(field, StructField): self.fields.append(field) - self.names.append(field.name) else: if isinstance(field, str) and data_type is None: raise ValueError("Must specify DataType if passing name of struct_field to create.") @@ -507,10 +503,18 @@ def add(self, field, data_type=None, nullable=True, metadata=None): else: data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) - self.names.append(field) - self._needSerializeAnyField = any(f.needConversion() for f in self) return self + @property + def names(self): + """Return 
a list of the field names""" + return [field.name for field in self] + + @property + def _needSerializeAnyField(self): + """Determine if any field needs conversion""" + return any(field.needConversion() for field in self) + def __iter__(self): """Iterate the fields""" return iter(self.fields) From 5ffffd981a4119891aa61cec29d0b877e887b5bd Mon Sep 17 00:00:00 2001 From: "Sheamus K. Parkes" Date: Thu, 7 Apr 2016 23:07:11 -0400 Subject: [PATCH 5/9] Expand pyspark.sql.types.StructType docstring. Do this so the new accessors are better documented. --- python/pyspark/sql/types.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 02eef46bc471a..cfa0fcd90b44c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -442,6 +442,15 @@ class StructType(DataType): """Struct type, consisting of a list of :class:`StructField`. This is the data type representing a :class:`Row`. + + Iterating a :class:`StructType` will iterate its :class:`StructField`s. + A contained :class:`StructField` can be accessed by name or position. + + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct1["f1"] + StructField(f1,StringType,true) + >>> struct1[0] + StructField(f1,StringType,true) """ def __init__(self, fields=None): """ From 4cf5d015b7a9e8647b6debfd2687a8830322b4f1 Mon Sep 17 00:00:00 2001 From: "Sheamus K. Parkes" Date: Mon, 18 Apr 2016 21:33:05 -0400 Subject: [PATCH 6/9] Correctly slice StructTypes. Do this because slicing containers in Python should return container objects of the same type. 
--- python/pyspark/sql/tests.py | 2 +- python/pyspark/sql/types.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index bdb49d0287a88..786da2690f374 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -815,7 +815,7 @@ def test_struct_type(self): struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) self.assertIs(struct1["f1"], struct1.fields[0]) self.assertIs(struct1[0], struct1.fields[0]) - self.assertEqual(struct1[0:1], struct1.fields[0:1]) + self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1])) with self.assertRaises(KeyError): not_a_field = struct1["f9"] with self.assertRaises(IndexError): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index cfa0fcd90b44c..6b1e1fc9f21b7 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -540,11 +540,13 @@ def __getitem__(self, key): return _dict_fields[key] except KeyError: raise KeyError('No StructField named {}'.format(key)) - elif isinstance(key, (int, slice)): + elif isinstance(key, int): try: return self.fields[key] except IndexError: raise IndexError('StructType index out of range') + elif isinstance(key, slice): + return StructType(self.fields[key]) else: raise TypeError('StructType keys should be strings, integers or slices') From 387cdfe0050f41f35710272e51593c484e5d0bfe Mon Sep 17 00:00:00 2001 From: "Sheamus K. Parkes" Date: Tue, 19 Apr 2016 15:07:04 -0400 Subject: [PATCH 7/9] Revert "Make pyspark.sql.types.StructType attributes lazier." This reverts commit 831f03a3098d0e165703db6be0f4d76ef4893667. 
--- python/pyspark/sql/tests.py | 3 --- python/pyspark/sql/types.py | 16 ++++++---------- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 786da2690f374..2fcb6eb5f27c8 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -823,9 +823,6 @@ def test_struct_type(self): with self.assertRaises(TypeError): not_a_field = struct1[9.9] - struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) - self.assertEqual(struct1.names, ["f1", "f2"]) - def test_metadata_null(self): from pyspark.sql.types import StructType, StringType, StructField schema = StructType([StructField("f1", StringType(), True, None), diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 6b1e1fc9f21b7..0ce5b449c6a2a 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -466,10 +466,13 @@ def __init__(self, fields=None): """ if not fields: self.fields = [] + self.names = [] else: self.fields = fields + self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" + self._needSerializeAnyField = any(f.needConversion() for f in self) def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -503,6 +506,7 @@ def add(self, field, data_type=None, nullable=True, metadata=None): """ if isinstance(field, StructField): self.fields.append(field) + self.names.append(field.name) else: if isinstance(field, str) and data_type is None: raise ValueError("Must specify DataType if passing name of struct_field to create.") @@ -512,18 +516,10 @@ def add(self, field, data_type=None, nullable=True, metadata=None): else: data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) + self.names.append(field) + self._needSerializeAnyField = any(f.needConversion() for f in self) return self - @property - def names(self): - """Return 
a list of the field names""" - return [field.name for field in self] - - @property - def _needSerializeAnyField(self): - """Determine if any field needs conversion""" - return any(field.needConversion() for field in self) - def __iter__(self): """Iterate the fields""" return iter(self.fields) From b11fc7dc9e16d38af3f6c93b07ef65b487dd29f7 Mon Sep 17 00:00:00 2001 From: "Sheamus K. Parkes" Date: Wed, 20 Apr 2016 14:34:32 -0400 Subject: [PATCH 8/9] Make changes compatible with python 2.6. Do this because Spark is still targeting python 2.6 compatibility. --- python/pyspark/sql/types.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 0ce5b449c6a2a..8f3787607583e 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -531,11 +531,13 @@ def __len__(self): def __getitem__(self, key): """Access fields by name or slice.""" if isinstance(key, str): - _dict_fields = {field.name: field for field in self} + _dict_fields = {} + for field in self: + _dict_fields[field.name] = field try: return _dict_fields[key] except KeyError: - raise KeyError('No StructField named {}'.format(key)) + raise KeyError('No StructField named {0}'.format(key)) elif isinstance(key, int): try: return self.fields[key] From fc7898b0999aba15fc43465fddc116cea7c14f06 Mon Sep 17 00:00:00 2001 From: "Sheamus K. Parkes" Date: Wed, 20 Apr 2016 14:42:17 -0400 Subject: [PATCH 9/9] Switch to returning field from inside loop. Do this because without the dictionary comprehension syntax this was looking very busy. 
--- python/pyspark/sql/types.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 8f3787607583e..f7cd4b80ca91d 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -531,13 +531,10 @@ def __len__(self): def __getitem__(self, key): """Access fields by name or slice.""" if isinstance(key, str): - _dict_fields = {} for field in self: - _dict_fields[field.name] = field - try: - return _dict_fields[key] - except KeyError: - raise KeyError('No StructField named {0}'.format(key)) + if field.name == key: + return field + raise KeyError('No StructField named {0}'.format(key)) elif isinstance(key, int): try: return self.fields[key]