Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
davies committed Jul 30, 2014
1 parent bc6e9e1 commit 182fb46
Showing 1 changed file with 48 additions and 112 deletions.
160 changes: 48 additions & 112 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,26 @@
"SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]


class DataType(object):

    """Spark SQL DataType

    Base class for all Spark SQL data types.  Provides value-based
    equality (two instances are equal when they are of the same class
    and carry the same instance attributes) together with a matching
    hash derived from the repr, so data types can be used as dict keys
    and set members.
    """

    def __repr__(self):
        # The plain class name; parameterized types (ArrayType, MapType,
        # ...) override this to embed their attributes.
        return self.__class__.__name__

    def __hash__(self):
        # Hashing the repr keeps __hash__ consistent with __eq__, since
        # parameterized subclasses include their attributes in the repr.
        return hash(repr(self))

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__, so define it
        # explicitly as the negation.
        return not (self == other)


class PrimitiveTypeSingleton(type):
"""Metaclass for PrimitiveType"""

_instances = {}

def __call__(cls):
Expand All @@ -44,140 +63,91 @@ def __call__(cls):
return cls._instances[cls]


class PrimitiveType(DataType):

    """Spark SQL PrimitiveType

    Base class for the parameter-less data types (StringType,
    IntegerType, ...).  The PrimitiveTypeSingleton metaclass appears to
    cache one instance per subclass (it returns ``cls._instances[cls]``),
    so repeatedly constructing a primitive type yields the same object.
    NOTE(review): ``__metaclass__`` is Python 2 syntax; under Python 3 it
    is silently ignored and the metaclass would have to be declared in
    the class header instead.
    """

    __metaclass__ = PrimitiveTypeSingleton


class StringType(PrimitiveType):

    """Spark SQL StringType

    The data type representing string values.
    """
    # __repr__, __eq__ and the singleton behavior are inherited from
    # PrimitiveType/DataType; no body is needed beyond the docstring.


class BinaryType(PrimitiveType):

    """Spark SQL BinaryType

    The data type representing bytearray values.
    """
    # Behavior (repr, equality, singleton instance) comes entirely from
    # the PrimitiveType/DataType base classes.


class BooleanType(PrimitiveType):

    """Spark SQL BooleanType

    The data type representing bool values.
    """
    # Behavior (repr, equality, singleton instance) comes entirely from
    # the PrimitiveType/DataType base classes.


class TimestampType(PrimitiveType):

    """Spark SQL TimestampType

    The data type representing datetime.datetime values.
    """
    # Behavior (repr, equality, singleton instance) comes entirely from
    # the PrimitiveType/DataType base classes.


class DecimalType(PrimitiveType):

    """Spark SQL DecimalType

    The data type representing decimal.Decimal values.
    """
    # Behavior (repr, equality, singleton instance) comes entirely from
    # the PrimitiveType/DataType base classes.


class DoubleType(PrimitiveType):

    """Spark SQL DoubleType

    The data type representing float values.
    """
    # Behavior (repr, equality, singleton instance) comes entirely from
    # the PrimitiveType/DataType base classes.


class FloatType(PrimitiveType):

    """Spark SQL FloatType

    The data type representing single precision floating-point values.
    """
    # Behavior (repr, equality, singleton instance) comes entirely from
    # the PrimitiveType/DataType base classes.


class ByteType(PrimitiveType):

    """Spark SQL ByteType

    The data type representing int values with 1 signed byte.
    """
    # Docstring typo fixed ("singed" -> "signed").  Behavior (repr,
    # equality, singleton instance) comes from PrimitiveType/DataType.


class IntegerType(PrimitiveType):

    """Spark SQL IntegerType

    The data type representing int values.
    """
    # Behavior (repr, equality, singleton instance) comes entirely from
    # the PrimitiveType/DataType base classes.


class LongType(PrimitiveType):

    """Spark SQL LongType

    The data type representing long values.  If any value is beyond the
    range of [-9223372036854775808, 9223372036854775807], please use
    DecimalType.
    """
    # Docstring grammar fixed ("If the any value" -> "If any value").
    # Behavior comes from the PrimitiveType/DataType base classes.


class ShortType(PrimitiveType):

    """Spark SQL ShortType

    The data type representing int values with 2 signed bytes.
    """
    # Behavior (repr, equality, singleton instance) comes entirely from
    # the PrimitiveType/DataType base classes.


class ArrayType(object):
class ArrayType(DataType):
"""Spark SQL ArrayType
The data type representing list values.
Expand All @@ -201,19 +171,12 @@ def __init__(self, elementType, containsNull=False):
self.containsNull = containsNull

def __repr__(self):
return "ArrayType(" + self.elementType.__repr__() + "," + \
str(self.containsNull).lower() + ")"
return "ArrayType(%r,%s)" % (self.elementType,
str(self.containsNull).lower())

def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.elementType == other.elementType and
self.containsNull == other.containsNull)

def __ne__(self, other):
return not self.__eq__(other)


class MapType(object):
class MapType(DataType):
"""Spark SQL MapType
The data type representing dict values.
Expand Down Expand Up @@ -241,21 +204,11 @@ def __init__(self, keyType, valueType, valueContainsNull=True):
self.valueContainsNull = valueContainsNull

def __repr__(self):
return "MapType(" + self.keyType.__repr__() + "," + \
self.valueType.__repr__() + "," + \
str(self.valueContainsNull).lower() + ")"

def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.keyType == other.keyType and
self.valueType == other.valueType and
self.valueContainsNull == other.valueContainsNull)

def __ne__(self, other):
return not self.__eq__(other)
return "MapType(%r,%r,%s)" % (self.keyType, self.valueType,
str(self.valueContainsNull).lower())


class StructField(object):
class StructField(DataType):
"""Spark SQL StructField
Represents a field in a StructType.
Expand All @@ -281,21 +234,11 @@ def __init__(self, name, dataType, nullable):
self.nullable = nullable

def __repr__(self):
return "StructField(" + self.name + "," + \
self.dataType.__repr__() + "," + \
str(self.nullable).lower() + ")"

def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.name == other.name and
self.dataType == other.dataType and
self.nullable == other.nullable)
return "StructField(%s,%r,%s)" % (self.name, self.dataType,
str(self.nullable).lower())

def __ne__(self, other):
return not self.__eq__(other)


class StructType(object):
class StructType(DataType):
"""Spark SQL StructType
The data type representing namedtuple values.
Expand All @@ -318,15 +261,8 @@ def __init__(self, fields):
self.fields = fields

def __repr__(self):
return "StructType(List(" + \
",".join([field.__repr__() for field in self.fields]) + "))"

def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.fields == other.fields)

def __ne__(self, other):
return not self.__eq__(other)
return ("StructType(List(%s))" %
",".join(repr(field) for field in self.fields))


def _parse_datatype_list(datatype_list_string):
Expand Down

0 comments on commit 182fb46

Please sign in to comment.