Skip to content

Commit

Permalink
fix bugs of groupby
Browse files Browse the repository at this point in the history
  • Loading branch information
继盛 committed Feb 5, 2016
1 parent 896b36b commit 6a99db6
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 13 deletions.
2 changes: 1 addition & 1 deletion odps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import sys

__version__ = '0.3.3'
__version__ = '0.3.4'
__all__ = ['ODPS',]

version = sys.version_info
Expand Down
8 changes: 4 additions & 4 deletions odps/df/backends/odpssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,15 +878,15 @@ def visit_shift_window(self, expr):
def visit_scalar(self, expr):
compiled = None
if expr._value is not None:
if isinstance(expr._value, bool):
if expr.dtype == df_types.string and isinstance(expr.value, six.text_type):
compiled = repr(utils.to_str(expr.value))
elif isinstance(expr._value, bool):
compiled = 'true' if expr._value else 'false'
elif isinstance(expr._value, datetime):
# FIXME: precision finer than one second is currently ignored
compiled= 'FROM_UNIXTIME({0})'.format(utils.to_timestamp(expr._value))
elif isinstance(expr._value, Decimal):
raise NotImplementedError
elif expr.dtype == df_types.string and isinstance(expr.value, six.text_type):
compiled = repr(utils.to_str(expr.value))
compiled = 'CAST({0} AS DECIMAL)'.format(repr(str(expr._value)))

if compiled is None:
compiled = repr(expr._value)
Expand Down
6 changes: 6 additions & 0 deletions odps/df/backends/odpssql/tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import re
from datetime import datetime
import base64 # noqa
from decimal import Decimal

import six

Expand Down Expand Up @@ -277,6 +278,11 @@ def testArithmeticCompilation(self):
'FROM mocked_project.`pyodps_test_expr_table` t1' % unix_time
self.assertEqual(to_str(expect), to_str(self.engine.compile(expr, prettify=False)))

expr = self.expr.scale < Decimal('3.14')
expect = "SELECT t1.`scale` < CAST('3.14' AS DECIMAL) AS scale \n" \
"FROM mocked_project.`pyodps_test_expr_table` t1"
self.assertEqual(to_str(expect), to_str(self.engine.compile(expr, prettify=False)))

def testMathCompilation(self):
for math_cls, func in MATH_COMPILE_DIC.items():
e = getattr(self.expr.id, math_cls.lower())()
Expand Down
14 changes: 12 additions & 2 deletions odps/df/expr/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,17 @@ class AnyOp(TypedExpr):
@classmethod
def _new_cls(cls, *args, **kwargs):
if '_data_type' in kwargs:
bases = cls, SequenceExpr._new_cls(cls, *args, **kwargs)
seq_cls = SequenceExpr._new_cls(cls, *args, **kwargs)
if issubclass(cls, seq_cls):
return cls
bases = cls, seq_cls
else:
assert '_value_type' in kwargs
bases = cls, Scalar._new_cls(cls, *args, **kwargs)

scalar_cls = Scalar._new_cls(cls, *args, **kwargs)
if issubclass(cls, scalar_cls):
return cls
bases = cls, scalar_cls

return type(cls.__name__, bases, dict(cls.__dict__))

Expand All @@ -49,6 +56,9 @@ class ElementWise(AnyOp):
def _new_cls(cls, *args, **kwargs):
base = AnyOp._new_cls(*args, **kwargs)

if issubclass(cls, base):
return cls

dic = dict(cls.__dict__)
dic['_args'] = cls._args
if '_add_args_slots' in dic:
Expand Down
6 changes: 4 additions & 2 deletions odps/df/expr/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,11 @@ def repr_obj(obj):


class Expr(Node):
__slots__ = '__execution', '__ban_optimize'
__slots__ = '__execution', '__ban_optimize', '_engine'

def __init__(self, *args, **kwargs):
self.__ban_optimize = False
self._engine = None
super(Expr, self).__init__(*args, **kwargs)

def __repr__(self):
Expand Down Expand Up @@ -1084,8 +1085,9 @@ def to_sequence(self):
}
if '_source_value_type' in kw:
kw['_source_data_type'] = kw.pop('_source_value_type')

cls = next(c for c in inspect.getmro(type(self))[1:]
if c.__name__ == type(self).__name__)
if c.__name__ == type(self).__name__ and not issubclass(c, Scalar))
seq = cls._new(**kw)

for attr, value in six.iteritems(attr_values):
Expand Down
2 changes: 1 addition & 1 deletion odps/df/expr/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ def _as_grouped(self, reduction_expr):
else:
continue

path[idx] = to_sub
if idx == 0:
path[0] = to_sub
root = to_sub
else:
path[idx - 1].substitute(node, to_sub)
Expand Down
7 changes: 6 additions & 1 deletion odps/df/expr/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ def testGroupbyReductions(self):
self.assertGreater(len(expr.aggregations), 0)
self.assertIsInstance(expr.aggregations[0], GroupedMedian)

metric = self.expr.int32.mean() > 10
field = (metric.ifelse(self.expr.int64.max(), 0) + 1).rename('int64_max')
expr = self.expr.groupby('string').agg(field)
self.assertIsInstance(expr, GroupByCollectionExpr)
self.assertIsInstance(expr.int64_max, Int64SequenceExpr)

def testGroupbyField(self):
grouped = self.expr.groupby(['int32', 'boolean']).string.sum()
self.assertIsInstance(grouped, StringSequenceExpr)
Expand All @@ -104,6 +110,5 @@ def testMutate(self):
self.assertSequenceEqual(expr._schema.types,
[types.int16, types.datetime, types.float64, types.int64])


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion odps/df/expr/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def testGetAttrs(self):
expr = CollectionExpr(_source_data=table, _schema=schema)

expected = ('_lhs', '_rhs', '_data_type', '_source_data_type', '_name',
'_source_name', '_cached_args')
'_source_name', '_engine', '_cached_args')
self.assertSequenceEqual(expected, get_attrs(expr.id + 1))

if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
long_description = f.read()

setup(name='pyodps',
version='0.3.3',
version='0.3.4',
description='ODPS Python SDK',
long_description=long_description,
author='Wu Wei',
Expand Down

0 comments on commit 6a99db6

Please sign in to comment.