Skip to content

Commit

Permalink
[REFACTOR][PY] Establish tvm.arith (#4904)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Feb 18, 2020
1 parent 38d1dd2 commit d1e1ac4
Show file tree
Hide file tree
Showing 14 changed files with 322 additions and 107 deletions.
22 changes: 22 additions & 0 deletions python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integer bound analysis, simplification and pattern detection."""

from .int_set import IntSet, IntervalSet
from .analyzer import ModularSet, ConstIntBound, Analyzer
from .bound import deduce_bound
from .pattern import detect_linear_equation, detect_clip_bound
21 changes: 21 additions & 0 deletions python/tvm/arith/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.arith"""
import tvm._ffi


tvm._ffi._init_api("arith", __name__)
72 changes: 29 additions & 43 deletions python/tvm/arith.py → python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,42 +17,15 @@
"""Arithmetic data structure and utility"""
import tvm._ffi
from tvm.runtime import Object


class IntSet(Object):
"""Represent a set of integer in one dimension."""
def is_nothing(self):
"""Whether the set represent nothing"""
return _IntSetIsNothing(self)

def is_everything(self):
"""Whether the set represent everything"""
return _IntSetIsEverything(self)


@tvm._ffi.register_object("arith.IntervalSet")
class IntervalSet(IntSet):
"""Represent set of continuous interval [min_value, max_value]
Parameters
----------
min_value : Expr
The minimum value in the interval.
max_value : Expr
The maximum value in the interval.
"""
def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(
_make_IntervalSet, min_value, max_value)
from . import _ffi_api


@tvm._ffi.register_object("arith.ModularSet")
class ModularSet(Object):
"""Represent range of (coeff * x + base) for x in Z """
def __init__(self, coeff, base):
self.__init_handle_by_constructor__(
_make_ModularSet, coeff, base)
_ffi_api.ModularSet, coeff, base)


@tvm._ffi.register_object("arith.ConstIntBound")
Expand All @@ -72,7 +45,7 @@ class ConstIntBound(Object):

def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(
_make_ConstIntBound, min_value, max_value)
_ffi_api.ConstIntBound, min_value, max_value)


class ConstraintScope:
Expand Down Expand Up @@ -105,11 +78,12 @@ class Analyzer:
be used to perform various symbolic integer analysis.
"""
def __init__(self):
_mod = _CreateAnalyzer()
_mod = _ffi_api.CreateAnalyzer()
self._const_int_bound = _mod("const_int_bound")
self._const_int_bound_update = _mod("const_int_bound_update")
self._bind = _mod("bind")
self._modular_set = _mod("modular_set")
self._simplify = _mod("Simplify")
self._rewrite_simplify = _mod("rewrite_simplify")
self._canonical_simplify = _mod("canonical_simplify")
self._int_set = _mod("int_set")
Expand All @@ -120,7 +94,7 @@ def const_int_bound(self, expr):
Parameters
----------
expr : tvm.Expr
expr : PrimExpr
The expression.
Returns
Expand All @@ -135,7 +109,7 @@ def modular_set(self, expr):
Parameters
----------
expr : tvm.Expr
expr : PrimExpr
The expression.
Returns
Expand All @@ -145,12 +119,27 @@ def modular_set(self, expr):
"""
return self._modular_set(expr)

def simplify(self, expr):
"""Simplify expression via both rewrite and canonicalization.
Parameters
----------
expr : PrimExpr
The expression.
Returns
-------
result : Expr
The result.
"""
return self._simplify(expr)

def rewrite_simplify(self, expr):
"""Simplify expression via rewriting rules.
Parameters
----------
expr : tvm.Expr
expr : PrimExpr
The expression.
Returns
Expand All @@ -165,7 +154,7 @@ def canonical_simplify(self, expr):
Parameters
----------
expr : tvm.Expr
expr : PrimExpr
The expression.
Returns
Expand All @@ -180,7 +169,7 @@ def int_set(self, expr, dom_map):
Parameters
----------
expr : tvm.Expr
expr : PrimExpr
The expression.
dom_map : Dict[Var, tvm.arith.IntSet]
Expand All @@ -198,10 +187,10 @@ def bind(self, var, expr):
Parameters
----------
var : tvm.Var
var : tvm.tir.Var
The variable.
expr : tvm.Expr
expr : PrimExpr
The expression.
"""
return self._bind(var, expr)
Expand All @@ -211,7 +200,7 @@ def constraint_scope(self, constraint):
Parameters
----------
constraint : tvm.Expr
constraint : PrimExpr
The constraint expression.
returns
Expand Down Expand Up @@ -240,7 +229,7 @@ def update(self, var, info, override=False):
Parameters
----------
var : tvm.Var
var : tvm.tir.Var
The variable.
info : tvm.Object
Expand All @@ -254,6 +243,3 @@ def update(self, var, info, override=False):
else:
raise TypeError(
"Do not know how to handle type {}".format(type(info)))


tvm._ffi._init_api("tvm.arith")
39 changes: 39 additions & 0 deletions python/tvm/arith/bound.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Bound deduction."""
from . import _ffi_api


def deduce_bound(var, cond, hint_map, relax_map):
"""Deduce the bound of the target variable in the cond.
Parameters
----------
var : Var
The target variable to be deduced.
cond : PrimExpr
The condition
hint_map : Map[Var, IntSet]
Domain of variables used to help deduction.
relax_map : Map[Var, IntSet]
The fomain of the variables to be relaxed
using the provided domain.
"""
return _ffi_api.DeduceBound(var, cond, hint_map, relax_map)
80 changes: 80 additions & 0 deletions python/tvm/arith/int_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Integer set."""
import tvm._ffi
from tvm.runtime import Object
from . import _ffi_api


class IntSet(Object):
"""Represent a set of integer in one dimension."""
def is_nothing(self):
"""Whether the set represent nothing"""
return _ffi_api.IntSetIsNothing(self)

def is_everything(self):
"""Whether the set represent everything"""
return _ffi_api.IntSetIsEverything(self)

@staticmethod
def vector(vec):
"""Construct an integer set that covers the vector expr
Parameters
----------
vec : PrimExpr
The vector expression.
Returns
-------
rset : IntSet
The result set.
"""
return _ffi_api.intset_vector(vec)

@staticmethod
def single_point(point):
"""Construct a point set.
Parameters
----------
point : PrimExpr
The vector expression.
Returns
-------
rset : IntSet
The result set.
"""
return _ffi_api.intset_single_point(point)


@tvm._ffi.register_object("arith.IntervalSet")
class IntervalSet(IntSet):
"""Represent set of continuous interval [min_value, max_value]
Parameters
----------
min_value : PrimExpr
The minimum value in the interval.
max_value : PrimExpr
The maximum value in the interval.
"""
def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(
_ffi_api.IntervalSet, min_value, max_value)
60 changes: 60 additions & 0 deletions python/tvm/arith/pattern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Detect common patterns."""
from . import _ffi_api


def detect_linear_equation(expr, var_list):
"""Match `expr = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]`
Where coeff[i] and base are invariant of var[j] for all i and j.
Parameters
----------
expr : PrimExpr
The expression to be matched.
var_list : List[tvm.tir.Var]
A list of variables.
Returns
-------
coeff : List[PrimExpr]
A list of co-efficients if the match is successful.
An empty list if the match failed.
"""
return _ffi_api.DetectLinearEquation(expr, var_list)


def detect_clip_bound(expr, var_list):
""" Detect if expression corresponds to clip bound of the vars
Parameters
----------
expr : PrimExpr
The expression to be matched.
var_list : List[tvm.tir.Var]
A list of variables.
Returns
-------
coeff : List[PrimExpr]
`concat([min_value[i], max_value[i]] for i, v in enumerate(var_list))`
An empty list if the match failed.
"""
return _ffi_api.DetectClipBound(expr, var_list)
Loading

0 comments on commit d1e1ac4

Please sign in to comment.