diff --git a/sdc/__init__.py b/sdc/__init__.py
index 3f2183345..5647ebd7d 100644
--- a/sdc/__init__.py
+++ b/sdc/__init__.py
@@ -62,6 +62,7 @@
     # sdc.datatypes.hpat_pandas_dataframe_pass.sdc_nopython_pipeline_lite_register
 
     import sdc.rewrites.dataframe_constructor
+    import sdc.rewrites.dataframe_getitem_attribute
     import sdc.datatypes.hpat_pandas_functions
     import sdc.datatypes.hpat_pandas_dataframe_functions
 else:
diff --git a/sdc/rewrites/dataframe_getitem_attribute.py b/sdc/rewrites/dataframe_getitem_attribute.py
new file mode 100644
index 000000000..e0e102839
--- /dev/null
+++ b/sdc/rewrites/dataframe_getitem_attribute.py
@@ -0,0 +1,104 @@
+# *****************************************************************************
+# Copyright (c) 2020, Intel Corporation All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+#     Redistributions of source code must retain the above copyright notice,
+#     this list of conditions and the following disclaimer.
+#
+#     Redistributions in binary form must reproduce the above copyright notice,
+#     this list of conditions and the following disclaimer in the documentation
+#     and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
+# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
+# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
+# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# *****************************************************************************
+
+from numba.ir import Assign, Const, Expr, Var
+from numba.ir_utils import mk_unique_var
+from numba.rewrites import register_rewrite, Rewrite
+from numba.types import StringLiteral
+from numba.typing import signature
+
+from sdc.config import config_pipeline_hpat_default
+from sdc.hiframes.pd_dataframe_type import DataFrameType
+
+
+if not config_pipeline_hpat_default:
+    @register_rewrite('after-inference')
+    class RewriteDataFrameGetItemAttr(Rewrite):
+        """
+        Search for calls of df.attr and replace it with calls of df['attr']:
+        $0.2 = getattr(value=df, attr=A) -> $const0.0 = const(str, A)
+                                            $0.2 = getitem(value=df, index=$const0.0)
+        """
+
+        def match(self, func_ir, block, typemap, calltypes):
+            """Collect getattr expressions that read a column of a dataframe.
+
+            Return True if the block contains at least one such expression.
+            """
+            self.func_ir = func_ir
+            self.block = block
+            self.typemap = typemap
+            self.calltypes = calltypes
+            self.getattrs = getattrs = set()
+            for expr in block.find_exprs(op='getattr'):
+                obj = typemap[expr.value.name]
+                if not isinstance(obj, DataFrameType):
+                    continue
+                if expr.attr in obj.columns:
+                    getattrs.add(expr)
+
+            return bool(getattrs)
+
+        def apply(self):
+            """Return a copy of the block with matched getattrs replaced by getitems."""
+            new_block = self.block.copy()
+            new_block.clear()
+            for inst in self.block.body:
+                if isinstance(inst, Assign) and inst.value in self.getattrs:
+                    # The column name constant must be defined before its use as index
+                    const_assign = self._assign_const(inst)
+                    new_block.append(const_assign)
+
+                    inst = self._assign_getitem(inst, index=const_assign.target)
+
+                new_block.append(inst)
+
+            return new_block
+
+        def _assign_const(self, inst, prefix='$const0'):
+            """Create constant from attribute of the instruction."""
+            const_node = Const(inst.value.attr, inst.loc)
+            const_var = Var(inst.target.scope, mk_unique_var(prefix), inst.loc)
+
+            self.func_ir._definitions[const_var.name] = [const_node]
+            self.typemap[const_var.name] = StringLiteral(inst.value.attr)
+
+            return Assign(const_node, const_var, inst.loc)
+
+        def _assign_getitem(self, inst, index):
+            """Create getitem instruction from the getattr instruction."""
+            new_expr = Expr.getitem(inst.value.value, index, inst.loc)
+            new_inst = Assign(value=new_expr, target=inst.target, loc=inst.loc)
+
+            # Definitions are keyed by variable name (as in _assign_const), not by Var
+            self.func_ir._definitions[inst.target.name] = [new_expr]
+            self.calltypes[new_expr] = signature(
+                self.typemap[inst.target.name],
+                self.typemap[new_expr.value.name],
+                self.typemap[new_expr.index.name]
+            )
+
+            return new_inst
diff --git a/sdc/tests/test_dataframe.py b/sdc/tests/test_dataframe.py
index a28eacf7b..63599df5a 100644
--- a/sdc/tests/test_dataframe.py
+++ b/sdc/tests/test_dataframe.py
@@ -1367,6 +1367,15 @@ def test_impl(df):
 
         pd.testing.assert_series_equal(sdc_func(df), test_impl(df))
 
+    def test_df_getitem_attr(self):
+        def test_impl(df):
+            return df.A
+
+        sdc_func = self.jit(test_impl)
+        df = gen_df(test_global_input_data_float64)
+
+        pd.testing.assert_series_equal(sdc_func(df), test_impl(df))
+
     @skip_numba_jit
     def test_isin_df1(self):
         def test_impl(df, df2):