Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@

import sdc.rewrites.dataframe_constructor
import sdc.rewrites.read_csv_consts
import sdc.rewrites.dict_zip_tuples
import sdc.rewrites.dataframe_getitem_attribute
import sdc.datatypes.hpat_pandas_functions
import sdc.datatypes.hpat_pandas_dataframe_functions
Expand Down
88 changes: 87 additions & 1 deletion sdc/functions/tuple_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# *****************************************************************************

from textwrap import dedent

from numba import types
from numba.extending import (intrinsic, )
from numba.extending import intrinsic
from numba.core.typing.templates import (signature, )
from numba.typed.dictobject import build_map

from sdc.utilities.utils import sdc_overload


@intrinsic
Expand Down Expand Up @@ -205,3 +210,84 @@ def codegen(context, builder, sig, args):
return context.make_tuple(builder, ret_type, [first_tup, second_tup])

return ret_type(data_type), codegen


def sdc_tuple_zip(x, y):
    """ Placeholder for zipping two tuples into a tuple of pairs. The pure-Python body
        does nothing and returns None; the working implementation is registered for
        this name separately via sdc_overload. """
    return None


@sdc_overload(sdc_tuple_zip)
def sdc_tuple_zip_ovld(x, y):
    """ Builds a tuple of pairs (x[i], y[i]) out of two input tuples x and y, keeping
        literality of their elements. The result has as many pairs as the shorter
        of the two inputs. Returns None (no implementation) for non-tuple arguments. """

    if not isinstance(x, types.BaseAnonymousTuple) or not isinstance(y, types.BaseAnonymousTuple):
        return None

    # FIXME_Numba#6533: alternatively we could have used sdc_tuple_map_elementwise
    # to avoid another use of exec, but due to @intrinsic-s not supporting
    # prefer_literal option the implementation below loses literality of args!
    # from sdc.functions.tuple_utils import sdc_tuple_map_elementwise
    # def sdc_tuple_zip_impl(x, y):
    #     return sdc_tuple_map_elementwise(
    #         lambda a, b: (a, b),
    #         x,
    #         y
    #     )
    #
    # return sdc_tuple_zip_impl

    func_impl_name = 'sdc_tuple_zip_impl'
    res_size = min(len(x), len(y))

    # every pair is spelled out explicitly in the generated source, one per index
    pairs = [f"(x[{i}], y[{i}])" for i in range(res_size)]
    tup_elements = ', '.join(pairs)
    # trailing comma keeps a single pair a tuple; an empty result is just "()"
    trailing_comma = ',' if res_size else ''

    func_text = dedent(f"""
    def {func_impl_name}(x, y):
        return ({tup_elements}{trailing_comma})
    """)

    use_globals = {}
    use_locals = {}
    exec(func_text, use_globals, use_locals)
    return use_locals[func_impl_name]


@intrinsic
def literal_dict_ctor(typingctx, items):
    """ Intrinsic creating a LiteralStrKeyDict from a tuple of (key, value) pairs.

        The typing part reads literal_value of the first element of each pair, so
        items is expected to be a tuple type of 2-element tuple types with literal
        keys (presumably validated by the caller - this function does no checks). """

    n_pairs = len(items)

    # position of every literal key, in the order pairs appear in the input tuple
    key_order = {}
    for pos, pair_type in enumerate(items):
        key_order[pair_type[0].literal_value] = pos

    ret_type = types.LiteralStrKeyDict(dict(items), key_order)

    def codegen(context, builder, sig, args):
        packed_pairs = args[0]

        def unpack_pair(idx):
            # pull the idx-th pair out of the tuple and split it into (key, value)
            pair = builder.extract_value(packed_pairs, idx)
            return (builder.extract_value(pair, 0), builder.extract_value(pair, 1))

        # build_map takes a list of (key, value) LLVM values rather than a packed tuple
        repacked_items = [unpack_pair(i) for i in range(n_pairs)]
        return build_map(context, builder, ret_type, items, repacked_items)

    return ret_type(items), codegen


@sdc_overload(dict)
def dict_from_tuples_ovld(x):
    """ Overload of dict builtin creating LiteralStrKeyDict from a tuple of pairs.

        Returns None (no implementation) when x is not a tuple type. Otherwise x must
        be a non-empty tuple, where every element is a tuple of size 2 with a literal
        string as the first element.

        Raises:
            TypingError: if x is a tuple that does not satisfy the requirements above.
    """

    # local import keeps this block self-contained w.r.t. module-level imports
    from numba.core.errors import TypingError

    if not isinstance(x, types.BaseAnonymousTuple):
        return None

    accepted_tuple_types = (types.Tuple, types.UniTuple)

    def check_tuple_element(ty):
        return (isinstance(ty, accepted_tuple_types)
                and len(ty) == 2
                and isinstance(ty[0], types.StringLiteral))

    # below checks that elements are tuples with size 2 and first element is literal string;
    # an explicit raise is used instead of assert, since asserts are stripped under -O and
    # unsupported input would then reach literal_dict_ctor unchecked
    if len(x) == 0 or not all(map(check_tuple_element, x)):
        raise TypingError(f"Creating LiteralStrKeyDict not supported from pairs of: {x}")

    def dict_from_tuples_impl(x):
        return literal_dict_ctor(x)

    return dict_from_tuples_impl
78 changes: 78 additions & 0 deletions sdc/rewrites/dict_zip_tuples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# *****************************************************************************
# Copyright (c) 2019-2021, Intel Corporation All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# *****************************************************************************

from numba.core.rewrites import register_rewrite, Rewrite
from numba.core.ir_utils import guard, get_definition
from numba import errors
from numba.core import ir

from sdc.rewrites.ir_utils import find_operations, import_function
from sdc.functions.tuple_utils import sdc_tuple_zip


@register_rewrite('before-inference')
class RewriteDictZip(Rewrite):
    """
    Searches for calls like dict(zip(arg1, arg2)) and replaces zip with sdc_tuple_zip.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Collect call expressions to zip (with exactly two arguments) that are used
        as the only argument of a call to dict within the given block.

        Returns True if at least one such expression was found.
        """

        self._block = block
        self._func_ir = func_ir
        self._calls_to_rewrite = set()

        # Find all assignments with a RHS expr being a call to dict, and where arg
        # is a call to zip and store these ir.Expr for further modification
        for inst in find_operations(block=block, op_name='call'):
            expr = inst.value
            try:
                callee = func_ir.infer_constant(expr.func)
            except errors.ConstantInferenceError:
                continue

            if not (callee is dict and len(expr.args) == 1):
                continue

            # guard() may give None if the definition cannot be resolved,
            # hence getattr with a default instead of direct attribute access
            dict_arg_expr = guard(get_definition, func_ir, expr.args[0])
            if getattr(dict_arg_expr, 'op', None) != 'call':
                continue

            # the definition of the called function may be unresolved (None) or
            # a node without 'value' attribute - such calls are not rewritten
            called_func = guard(get_definition, func_ir, dict_arg_expr.func)
            if (getattr(called_func, 'value', None) is zip
                    and len(dict_arg_expr.args) == 2):
                self._calls_to_rewrite.add(dict_arg_expr)

        return len(self._calls_to_rewrite) > 0

    def apply(self):
        """
        Replace call to zip in matched expressions with call to sdc_tuple_zip.

        Returns a new block with the same statements as the matched one, preceded
        by whatever import_function adds to make sdc_tuple_zip available.
        """
        new_block = self._block.copy()
        new_block.clear()
        zip_spec_stmt = import_function(sdc_tuple_zip, new_block, self._func_ir)
        for inst in self._block.body:
            if isinstance(inst, ir.Assign) and inst.value in self._calls_to_rewrite:
                expr = inst.value
                expr.func = zip_spec_stmt.target  # injects the new function
            new_block.append(inst)
        return new_block
48 changes: 47 additions & 1 deletion sdc/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
import pandas as pd
import random
import unittest
from itertools import product

from numba import types
from numba.tests.support import MemoryLeakMixin

import sdc
from sdc.tests.test_base import TestCase
Expand All @@ -43,7 +46,8 @@
dist_IR_contains,
get_rank,
get_start_end,
skip_numba_jit)
skip_numba_jit,
assert_nbtype_for_varname)


def get_np_state_ptr():
Expand Down Expand Up @@ -540,5 +544,47 @@ def test_rhs(arr_len):
np.testing.assert_allclose(A, B)


class TestPython(MemoryLeakMixin, TestCase):

    def test_literal_dict_ctor(self):
        """ Checks that calling dict builtin on a tuple of pairs ('col_name_i', col_data_i),
            with col_name_i being a literal string, produces LiteralStrKeyDict """

        def test_impl_1():
            items = (('A', np.arange(11)), )
            res = dict(items)
            return len(res)

        def test_impl_2():
            items = (('A', np.arange(5)), ('B', np.ones(11)), )
            res = dict(items)
            return len(res)

        for test_impl in (test_impl_1, test_impl_2):
            with self.subTest(tested_func_name=test_impl.__name__):
                sdc_func = self.jit(test_impl)
                self.assertEqual(sdc_func(), test_impl())
                assert_nbtype_for_varname(self, sdc_func, 'res', types.LiteralStrKeyDict)

    def test_dict_zip_rewrite(self):
        """ Checks that a combination of dict(zip()) produces LiteralStrKeyDict when
            zip is applied to tuples of literal column names and columns data """

        dict_keys = ('A', 'B')
        dict_values = (np.ones(5), np.array([1, 2, 3]))

        def test_impl():
            res = dict(zip(dict_keys, dict_values))
            return len(res)

        # reference result comes from plain Python dict(zip(...))
        expected = len(dict(zip(dict_keys, dict_values)))

        sdc_func = self.jit(test_impl)
        self.assertEqual(sdc_func(), expected)
        assert_nbtype_for_varname(self, sdc_func, 'res', types.LiteralStrKeyDict)


if __name__ == "__main__":
unittest.main()
7 changes: 7 additions & 0 deletions sdc/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,10 @@ def _make_func_from_text(func_text, func_name='test_impl', global_vars={}):
exec(func_text, global_vars, loc_vars)
test_impl = loc_vars[func_name]
return test_impl


def assert_nbtype_for_varname(self, disp, var, expected_type, fn_sig=None):
    """ Asserts that variable var of a compiled function has a type that is
        an instance of expected_type.

        self: test case object providing assertIsInstance
        disp: dispatcher of the compiled function
        var: name of the variable looked up in the function's typemap
        expected_type: type class (or tuple of classes) the variable type must match
        fn_sig: signature of the compilation result to inspect; when None the first
            of disp.nopython_signatures is used
    """
    # 'is None' rather than truthiness: an explicitly given signature may be falsy,
    # e.g. the empty tuple for a function with no arguments
    if fn_sig is None:
        fn_sig = disp.nopython_signatures[0]
    cres = disp.get_compile_result(fn_sig)
    fn_typemap = cres.type_annotation.typemap
    self.assertIsInstance(fn_typemap[var], expected_type)