[SPARK-3713][SQL] Uses JSON to serialize DataType objects
This PR uses JSON instead of `toString` to serialize `DataType`s. The latter is not only hard to parse but also flaky in many cases.

Since we already write schema information to Parquet metadata in the old `toString` format, we have to retain the old `DataType` parser to ensure backward compatibility. The old parser is now renamed to `CaseClassStringParser` and moved into `object DataType`.
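
For illustration, the two formats compare roughly like this (a sketch against the PySpark classes touched in this diff, which live in `pyspark.sql` in this version):

```python
from pyspark.sql import StructType, StructField, IntegerType

schema = StructType([StructField("f1", IntegerType(), False)])

# Old `toString`-style representation (what str(schema) produces):
#   StructType(List(StructField(f1,IntegerType,false)))

# New JSON representation, emitted compactly with sorted keys:
print(schema.json())
# {"fields":[{"name":"f1","nullable":false,"type":"integer"}],"type":"struct"}
```

Being plain JSON, the new format round-trips through any standard JSON library instead of a hand-rolled string parser.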

JoshRosen, davies: please help review the PySpark-related changes, thanks!

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #2563 from liancheng/datatype-to-json and squashes the following commits:

fc92eb3 [Cheng Lian] Reverts debugging code, simplifies primitive type JSON representation
438c75f [Cheng Lian] Refactors PySpark DataType JSON SerDe per comments
6b6387b [Cheng Lian] Removes debugging code
6a3ee3a [Cheng Lian] Addresses per review comments
dc158b5 [Cheng Lian] Addresses PEP8 issues
99ab4ee [Cheng Lian] Adds compatibility test case for Parquet type conversion
a983a6c [Cheng Lian] Adds PySpark support
f608c6e [Cheng Lian] De/serializes DataType objects from/to JSON
liancheng authored and marmbrus committed Oct 9, 2014
1 parent a85f24a commit a42cc08
Showing 7 changed files with 277 additions and 168 deletions.
153 changes: 75 additions & 78 deletions python/pyspark/sql.py
@@ -34,6 +34,7 @@
import datetime
import keyword
import warnings
import json
from array import array
from operator import itemgetter
from itertools import imap
@@ -71,6 +72,18 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

@classmethod
def typeName(cls):
return cls.__name__[:-4].lower()

def jsonValue(self):
return self.typeName()

def json(self):
return json.dumps(self.jsonValue(),
separators=(',', ':'),
sort_keys=True)


class PrimitiveTypeSingleton(type):

@@ -214,6 +227,16 @@ def __repr__(self):
return "ArrayType(%s,%s)" % (self.elementType,
str(self.containsNull).lower())

def jsonValue(self):
return {"type": self.typeName(),
"elementType": self.elementType.jsonValue(),
"containsNull": self.containsNull}

@classmethod
def fromJson(cls, json):
return ArrayType(_parse_datatype_json_value(json["elementType"]),
json["containsNull"])


class MapType(DataType):

@@ -254,6 +277,18 @@ def __repr__(self):
return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
str(self.valueContainsNull).lower())

def jsonValue(self):
return {"type": self.typeName(),
"keyType": self.keyType.jsonValue(),
"valueType": self.valueType.jsonValue(),
"valueContainsNull": self.valueContainsNull}

@classmethod
def fromJson(cls, json):
return MapType(_parse_datatype_json_value(json["keyType"]),
_parse_datatype_json_value(json["valueType"]),
json["valueContainsNull"])


class StructField(DataType):

@@ -292,6 +327,17 @@ def __repr__(self):
return "StructField(%s,%s,%s)" % (self.name, self.dataType,
str(self.nullable).lower())

def jsonValue(self):
return {"name": self.name,
"type": self.dataType.jsonValue(),
"nullable": self.nullable}

@classmethod
def fromJson(cls, json):
return StructField(json["name"],
_parse_datatype_json_value(json["type"]),
json["nullable"])


class StructType(DataType):

@@ -321,42 +367,30 @@ def __repr__(self):
return ("StructType(List(%s))" %
",".join(str(field) for field in self.fields))

def jsonValue(self):
return {"type": self.typeName(),
"fields": [f.jsonValue() for f in self.fields]}

def _parse_datatype_list(datatype_list_string):
"""Parses a list of comma separated data types."""
index = 0
datatype_list = []
start = 0
depth = 0
while index < len(datatype_list_string):
if depth == 0 and datatype_list_string[index] == ",":
datatype_string = datatype_list_string[start:index].strip()
datatype_list.append(_parse_datatype_string(datatype_string))
start = index + 1
elif datatype_list_string[index] == "(":
depth += 1
elif datatype_list_string[index] == ")":
depth -= 1
@classmethod
def fromJson(cls, json):
return StructType([StructField.fromJson(f) for f in json["fields"]])

index += 1

# Handle the last data type
datatype_string = datatype_list_string[start:index].strip()
datatype_list.append(_parse_datatype_string(datatype_string))
return datatype_list
_all_primitive_types = dict((v.typeName(), v)
for v in globals().itervalues()
if type(v) is PrimitiveTypeSingleton and
v.__base__ == PrimitiveType)


_all_primitive_types = dict((k, v) for k, v in globals().iteritems()
if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType)
_all_complex_types = dict((v.typeName(), v)
for v in [ArrayType, MapType, StructType])


def _parse_datatype_string(datatype_string):
"""Parses the given data type string.
def _parse_datatype_json_string(json_string):
"""Parses the given data type JSON string.
>>> def check_datatype(datatype):
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype))
... python_datatype = _parse_datatype_string(
... scala_datatype.toString())
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
... python_datatype = _parse_datatype_json_string(scala_datatype.json())
... return datatype == python_datatype
>>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
True
@@ -394,51 +428,14 @@ def _parse_datatype_string(datatype_string):
>>> check_datatype(complex_maptype)
True
"""
index = datatype_string.find("(")
if index == -1:
# It is a primitive type.
index = len(datatype_string)
type_or_field = datatype_string[:index]
rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip()

if type_or_field in _all_primitive_types:
return _all_primitive_types[type_or_field]()

elif type_or_field == "ArrayType":
last_comma_index = rest_part.rfind(",")
containsNull = True
if rest_part[last_comma_index + 1:].strip().lower() == "false":
containsNull = False
elementType = _parse_datatype_string(
rest_part[:last_comma_index].strip())
return ArrayType(elementType, containsNull)

elif type_or_field == "MapType":
last_comma_index = rest_part.rfind(",")
valueContainsNull = True
if rest_part[last_comma_index + 1:].strip().lower() == "false":
valueContainsNull = False
keyType, valueType = _parse_datatype_list(
rest_part[:last_comma_index].strip())
return MapType(keyType, valueType, valueContainsNull)

elif type_or_field == "StructField":
first_comma_index = rest_part.find(",")
name = rest_part[:first_comma_index].strip()
last_comma_index = rest_part.rfind(",")
nullable = True
if rest_part[last_comma_index + 1:].strip().lower() == "false":
nullable = False
dataType = _parse_datatype_string(
rest_part[first_comma_index + 1:last_comma_index].strip())
return StructField(name, dataType, nullable)

elif type_or_field == "StructType":
# rest_part should be in the format like
# List(StructField(field1,IntegerType,false)).
field_list_string = rest_part[rest_part.find("(") + 1:-1]
fields = _parse_datatype_list(field_list_string)
return StructType(fields)
return _parse_datatype_json_value(json.loads(json_string))


def _parse_datatype_json_value(json_value):
if type(json_value) is unicode and json_value in _all_primitive_types.keys():
return _all_primitive_types[json_value]()
else:
return _all_complex_types[json_value["type"]].fromJson(json_value)
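
Together these two helpers accept both shapes of serialized value: a bare string for primitive types and an object carrying a "type" key for complex types. A minimal round-trip sketch using the functions defined in this file:

```python
>>> original = MapType(StringType(), ArrayType(IntegerType(), True), False)
>>> _parse_datatype_json_string(original.json()) == original
True
```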


# Mapping Python types to Spark SQL DataType
@@ -992,7 +989,7 @@ def registerFunction(self, name, f, returnType=StringType()):
self._sc.pythonExec,
broadcast_vars,
self._sc._javaAccumulator,
str(returnType))
returnType.json())

def inferSchema(self, rdd):
"""Infer and apply a schema to an RDD of L{Row}.
@@ -1128,7 +1125,7 @@ def applySchema(self, rdd, schema):

batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
jrdd = self._pythonToJava(rdd._jrdd, batched)
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return SchemaRDD(srdd.toJavaSchemaRDD(), self)

def registerRDDAsTable(self, rdd, tableName):
@@ -1218,7 +1215,7 @@ def jsonFile(self, path, schema=None):
if schema is None:
srdd = self._ssql_ctx.jsonFile(path)
else:
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)

@@ -1288,7 +1285,7 @@ def func(iterator):
if schema is None:
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
else:
scala_datatype = self._ssql_ctx.parseDataType(str(schema))
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)

@@ -1623,7 +1620,7 @@ def saveAsTable(self, tableName):
def schema(self):
"""Returns the schema of this SchemaRDD (represented by
a L{StructType})."""
return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())
return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json())

def schemaString(self):
"""Returns the output schema in the tree format."""
@@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.types.DataType
/**
* The data type representing [[DynamicRow]] values.
*/
case object DynamicType extends DataType {
def simpleString: String = "dynamic"
}
case object DynamicType extends DataType

/**
* Wrap a [[Row]] as a [[DynamicRow]].