diff --git a/sdc/__init__.py b/sdc/__init__.py
index e73c51682..0c3235441 100644
--- a/sdc/__init__.py
+++ b/sdc/__init__.py
@@ -70,6 +70,7 @@
 
 import sdc.rewrites.dataframe_constructor
 import sdc.rewrites.read_csv_consts
+import sdc.rewrites.dict_zip_tuples
 import sdc.rewrites.dataframe_getitem_attribute
 import sdc.datatypes.hpat_pandas_functions
 import sdc.datatypes.hpat_pandas_dataframe_functions
diff --git a/sdc/functions/tuple_utils.py b/sdc/functions/tuple_utils.py
index 17dffa200..aa41db68f 100644
--- a/sdc/functions/tuple_utils.py
+++ b/sdc/functions/tuple_utils.py
@@ -25,9 +25,15 @@
 # EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 # *****************************************************************************
 
+from textwrap import dedent
+
 from numba import types
-from numba.extending import (intrinsic, )
+from numba.core.errors import TypingError
+from numba.extending import intrinsic
 from numba.core.typing.templates import (signature, )
+from numba.typed.dictobject import build_map
+
+from sdc.utilities.utils import sdc_overload
 
 
 @intrinsic
@@ -205,3 +211,92 @@ def codegen(context, builder, sig, args):
         return context.make_tuple(builder, ret_type, [first_tup, second_tup])
 
     return ret_type(data_type), codegen
+
+
+def sdc_tuple_zip(x, y):
+    """ Pure-python stub, which does nothing when called directly: it only serves as a target
+        for the overload defined below. """
+    pass
+
+
+@sdc_overload(sdc_tuple_zip)
+def sdc_tuple_zip_ovld(x, y):
+    """ This function combines tuple of pairs (x[i], y[i]) from two input tuples x and y,
+        preserving literality of elements in them.
+
+        No implementation is provided (None is returned) unless both x and y are typed as
+        types.BaseAnonymousTuple. The resulting tuple has min(len(x), len(y)) pairs. """
+
+    if not (isinstance(x, types.BaseAnonymousTuple) and isinstance(y, types.BaseAnonymousTuple)):
+        return None
+
+    # FIXME_Numba#6533: alternatively we could have used sdc_tuple_map_elementwise
+    # to avoid another use of exec, but due to @intrinsic-s not supporting
+    # prefer_literal option the implementation below loses literality of args:
+    #     def sdc_tuple_zip_impl(x, y):
+    #         return sdc_tuple_map_elementwise(lambda a, b: (a, b), x, y)
+
+    # the implementation is generated from text, as the number of pairs depends on the sizes
+    # of tuple types; the trailing comma makes a result with a single pair a tuple of size 1
+    res_size = min(len(x), len(y))
+    func_impl_name = 'sdc_tuple_zip_impl'
+    tup_elements = ', '.join([f"(x[{i}], y[{i}])" for i in range(res_size)])
+    func_text = dedent(f"""
+    def {func_impl_name}(x, y):
+        return ({tup_elements}{',' if res_size else ''})
+    """)
+    use_globals, use_locals = {}, {}
+    exec(func_text, use_globals, use_locals)
+    return use_locals[func_impl_name]
+
+
+@intrinsic
+def literal_dict_ctor(typingctx, items):
+    """ Creates types.LiteralStrKeyDict from items, i.e. a tuple of pairs with literal first
+        elements; key_order maps the literal value of each key to position of its pair in items. """
+
+    tup_size = len(items)
+    key_order = {p[0].literal_value: i for i, p in enumerate(items)}
+    ret_type = types.LiteralStrKeyDict(dict(items), key_order)
+
+    def codegen(context, builder, sig, args):
+        items_val = args[0]
+
+        # extract elements from the input tuple and repack into a list of variables required by build_map
+        repacked_items = []
+        for i in range(tup_size):
+            elem = builder.extract_value(items_val, i)
+            elem_first = builder.extract_value(elem, 0)
+            elem_second = builder.extract_value(elem, 1)
+            repacked_items.append((elem_first, elem_second))
+        d = build_map(context, builder, ret_type, items, repacked_items)
+        return d
+
+    return ret_type(items), codegen
+
+
+@sdc_overload(dict)
+def dict_from_tuples_ovld(x):
+    """ Overload of dict builtin, that creates LiteralStrKeyDict from a tuple of pairs, where
+        each pair is a tuple of size 2 with a literal string as the first element.
+
+        No implementation is provided (None is returned) if x is not types.BaseAnonymousTuple.
+        Raises TypingError if x is empty or any of its elements is not a pair described above. """
+
+    accepted_tuple_types = (types.Tuple, types.UniTuple)
+    if not isinstance(x, types.BaseAnonymousTuple):
+        return None
+
+    def check_tuple_element(ty):
+        return (isinstance(ty, accepted_tuple_types)
+                and len(ty) == 2
+                and isinstance(ty[0], types.StringLiteral))
+
+    # below checks that elements are tuples with size 2 and first element is literal string;
+    # an exception is raised instead of using assert, since asserts are removed when running with -O
+    if not (len(x) != 0 and all(map(check_tuple_element, x))):
+        raise TypingError(f"Creating LiteralStrKeyDict not supported from pairs of: {x}")
+
+    def dict_from_tuples_impl(x):
+        return literal_dict_ctor(x)
+    return dict_from_tuples_impl
diff --git a/sdc/rewrites/dict_zip_tuples.py b/sdc/rewrites/dict_zip_tuples.py
new file mode 100644
index 000000000..96e38f5ca
--- /dev/null
+++ b/sdc/rewrites/dict_zip_tuples.py
@@ -0,0 +1,78 @@
+# *****************************************************************************
+# Copyright (c) 2019-2021, Intel Corporation All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+#     Redistributions of source code must retain the above copyright notice,
+#     this list of conditions and the following disclaimer.
+#
+#     Redistributions in binary form must reproduce the above copyright notice,
+#     this list of conditions and the following disclaimer in the documentation
+#     and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
+# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# *****************************************************************************
+
+from numba.core.rewrites import register_rewrite, Rewrite
+from numba.core.ir_utils import guard, get_definition
+from numba import errors
+from numba.core import ir
+
+from sdc.rewrites.ir_utils import find_operations, import_function
+from sdc.functions.tuple_utils import sdc_tuple_zip
+
+
+@register_rewrite('before-inference')
+class RewriteDictZip(Rewrite):
+    """
+    Searches for calls like dict(zip(arg1, arg2)) and replaces zip with sdc_tuple_zip.
+    """
+
+    def match(self, func_ir, block, typemap, calltypes):
+        """ Stores zip calls that are the only argument of a dict call; returns True if any was found. """
+        self._block = block
+        self._func_ir = func_ir
+        self._calls_to_rewrite = set()
+
+        # Find all assignments with a RHS expr being a call to dict, and where arg is a call to zip and store
+        # these ir.Expr for further modification (guard gives None for unresolved definitions, hence getattr)
+        for inst in find_operations(block=block, op_name='call'):
+            expr = inst.value
+            try:
+                callee = func_ir.infer_constant(expr.func)
+            except errors.ConstantInferenceError:
+                continue
+
+            if (callee is dict and len(expr.args) == 1):
+                dict_arg_expr = guard(get_definition, func_ir, expr.args[0])
+                if (getattr(dict_arg_expr, 'op', None) == 'call'):
+                    called_func = guard(get_definition, func_ir, dict_arg_expr.func)
+                    if (getattr(called_func, 'value', None) is zip and len(dict_arg_expr.args) == 2):
+                        self._calls_to_rewrite.add(dict_arg_expr)
+
+        return len(self._calls_to_rewrite) > 0
+
+    def apply(self):
+        """
+        Rebuilds the matched block, replacing callee of the stored zip calls with imported sdc_tuple_zip.
+        """
+        new_block = self._block.copy()
+        new_block.clear()
+        zip_spec_stmt = import_function(sdc_tuple_zip, new_block, self._func_ir)
+        for inst in self._block.body:
+            if isinstance(inst, ir.Assign) and inst.value in self._calls_to_rewrite:
+                expr = inst.value
+                expr.func = zip_spec_stmt.target  # injects the new function
+            new_block.append(inst)
+        return new_block
diff --git a/sdc/tests/test_basic.py b/sdc/tests/test_basic.py
index 413905dc9..21cd245f3 100644
--- a/sdc/tests/test_basic.py
+++ b/sdc/tests/test_basic.py
@@ -30,7 +30,10 @@
 import pandas as pd
 import random
 import unittest
+from itertools import product
+
 from numba import types
+from numba.tests.support import MemoryLeakMixin
 
 import sdc
 from sdc.tests.test_base import TestCase
@@ -43,7 +46,8 @@
                                   dist_IR_contains,
                                   get_rank,
                                   get_start_end,
-                                  skip_numba_jit)
+                                  skip_numba_jit,
+                                  assert_nbtype_for_varname)
 
 
 def get_np_state_ptr():
@@ -540,5 +544,46 @@ def test_rhs(arr_len):
         np.testing.assert_allclose(A, B)
 
 
+class TestPython(MemoryLeakMixin, TestCase):
+
+    def test_literal_dict_ctor(self):
+        """ Verifies that dict builtin creates LiteralStrKeyDict from tuple
+            of pairs ('col_name_i', col_data_i), where col_name_i is literal string """
+
+        def test_impl_1():
+            items = (('A', np.arange(11)), )
+            res = dict(items)
+            return len(res)
+
+        def test_impl_2():
+            items = (('A', np.arange(5)), ('B', np.ones(11)), )
+            res = dict(items)
+            return len(res)
+
+        list_tested_fns = (test_impl_1, test_impl_2)
+
+        for test_impl in list_tested_fns:
+            with self.subTest(tested_func_name=test_impl.__name__):
+                sdc_func = self.jit(test_impl)
+                self.assertEqual(sdc_func(), test_impl())
+                assert_nbtype_for_varname(self, sdc_func, 'res', types.LiteralStrKeyDict)
+
+    def test_dict_zip_rewrite(self):
+        """ Verifies that a combination of dict(zip()) creates LiteralStrKeyDict when
+            zip is applied to tuples of literal column names and columns data """
+
+        dict_keys = ('A', 'B')
+        dict_values = (np.ones(5), np.array([1, 2, 3]))
+
+        def test_impl():
+            res = dict(zip(dict_keys, dict_values))
+            return len(res)
+
+        sdc_func = self.jit(test_impl)
+        expected = len(dict(zip(dict_keys, dict_values)))
+        self.assertEqual(sdc_func(), expected)
+        assert_nbtype_for_varname(self, sdc_func, 'res', types.LiteralStrKeyDict)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/sdc/tests/test_utils.py b/sdc/tests/test_utils.py
index 110c7424b..571da7d1a 100644
--- a/sdc/tests/test_utils.py
+++ b/sdc/tests/test_utils.py
@@ -272,3 +272,12 @@ def _make_func_from_text(func_text, func_name='test_impl', global_vars={}):
     exec(func_text, global_vars, loc_vars)
     test_impl = loc_vars[func_name]
     return test_impl
+
+
+def assert_nbtype_for_varname(self, disp, var, expected_type, fn_sig=None):
+    """ Asserts that in the typemap of the result of compiling disp for signature fn_sig variable var
+        has a type that is an instance of expected_type. If fn_sig is not given, first of nopython signatures is used. """
+    fn_sig = fn_sig or disp.nopython_signatures[0]
+    cres = disp.get_compile_result(fn_sig)
+    fn_typemap = cres.type_annotation.typemap
+    self.assertIsInstance(fn_typemap[var], expected_type)