Skip to content

Commit

Permalink
support date and datetime by auto_convert
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Apr 18, 2015
1 parent cb094ff commit 3c373f3
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2267,6 +2267,8 @@ def _prepare_for_python_RDD(sc, command, obj=None):
# The broadcast will have same life cycle as created PythonRDD
broadcast = sc.broadcast(pickled_command)
pickled_command = ser.dumps(broadcast)
# There is a bug in py4j.java_gateway.JavaClass with auto_convert
# TODO: use auto_convert once py4j fix the bug
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in sc._pickled_broadcast_vars],
sc._gateway._gateway_client)
Expand Down
27 changes: 27 additions & 0 deletions python/pyspark/sql/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import sys
import decimal
import time
import datetime
import keyword
import warnings
Expand All @@ -30,6 +31,9 @@
long = int
unicode = str

from py4j.protocol import register_input_converter
from py4j.java_gateway import JavaClass

__all__ = [
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
"TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
Expand Down Expand Up @@ -1237,6 +1241,29 @@ def __repr__(self):
return "<Row(%s)>" % ", ".join(self)


class DateConverter(object):
def can_convert(self, obj):
return isinstance(obj, datetime.date)

def convert(self, obj, gateway_client):
Date = JavaClass("java.sql.Date", gateway_client)
return Date.valueOf(obj.strftime("%Y-%m-%d"))


class DatetimeConverter(object):
def can_convert(self, obj):
return isinstance(obj, datetime.datetime)

def convert(self, obj, gateway_client):
Timestamp = JavaClass("java.sql.Timestamp", gateway_client)
return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000)


# datetime is a subclass of date, we should register DatetimeConverter first
register_input_converter(DatetimeConverter())
register_input_converter(DateConverter())


def _test():
import doctest
from pyspark.context import SparkContext
Expand Down
11 changes: 11 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import tempfile
import pickle
import functools
import datetime

import py4j

Expand Down Expand Up @@ -464,6 +465,16 @@ def test_infer_long_type(self):
self.assertEqual(_infer_type(2**61), LongType())
self.assertEqual(_infer_type(2**71), LongType())

def test_filter_with_datetime(self):
time = datetime.datetime(2015, 4, 17, 23, 01, 02, 3000)
date = time.date()
row = Row(date=date, time=time)
df = self.sqlCtx.createDataFrame([row])
self.assertEqual(1, df.filter(df.date == date).count())
self.assertEqual(1, df.filter(df.time == time).count())
self.assertEqual(0, df.filter(df.date > date).count())
self.assertEqual(0, df.filter(df.time > time).count())

def test_dropna(self):
schema = StructType([
StructField("name", StringType(), True),
Expand Down

0 comments on commit 3c373f3

Please sign in to comment.