Skip to content

Commit

Permalink
[COMPTIME] Specialize Constant._binary() for compilation speedup (h…
Browse files Browse the repository at this point in the history
…idet-org#148)

**Overview** 
Specialize function `Constant._binary()` for compilation speedup

**Compilation time improvement results** 
matmul_f16 with `max_parallel_jobs=1`
Before: 2m 11.2s
After: 2m 4.4s
Speedup: 5.5%

**Additional test**
matmul_f16 has 177 candidates. I checked that all of them remained the same(no functional changes)
  • Loading branch information
vadiklyutiy committed Apr 20, 2024
1 parent a220a35 commit eefc9d8
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 6 deletions.
1 change: 1 addition & 0 deletions python/hidet/ir/dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .vector import float16x2, float32x4, float32x8, int8x4, uint8x4, int4bx8, uint4bx8, vectorize
from .vector import f16x2, f32x4, f32x8, i4x8, u4x8
from .complex import complex64, complex128
from .integer import IntegerType
from .promotion import promote_type
from .utils import dtype_to_numpy, finfo, iinfo

Expand Down
47 changes: 41 additions & 6 deletions python/hidet/ir/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import operator
import numpy as np
import hidet.option
from hidet.ir.dtypes import boolean, int32, int64, uint64, IntegerType, promote_type
from .node import Node
from .type import BaseType, TensorType, DataType, TensorPointerType, PointerType, FuncType, StringType, ArrayType
from .type import tensor_pointer_type, string_type, tensor_type, data_type
Expand Down Expand Up @@ -217,8 +218,6 @@ def _binary(cls, a: Expr, b: Expr): # pylint: disable=bad-staticmethod-argument
if not isinstance(b, Expr):
b = convert(b)
if isinstance(a, Constant) and isinstance(b, Constant):
from hidet.ir.dtypes import promote_type

if a.type.is_data_type() and b.type.is_data_type():
value = operator_dict[cls](a.value, b.value)
if cls in [Equal, NotEqual, LessThan, LessEqual, LogicalAnd, LogicalOr]:
Expand Down Expand Up @@ -248,7 +247,6 @@ def _binary(cls, a: Expr, b: Expr): # pylint: disable=bad-staticmethod-argument
else:
raise ValueError('unknown binary operator {}'.format(cls))
elif isinstance(b, Constant) and b.type.is_data_type():
from hidet.ir.dtypes import promote_type
from hidet.ir.tools import infer_type

if b == 0:
Expand All @@ -259,7 +257,6 @@ def _binary(cls, a: Expr, b: Expr): # pylint: disable=bad-staticmethod-argument
elif b == 1 and cls in [Multiply, Div]:
return a
elif isinstance(a, Constant):
from hidet.ir.dtypes import promote_type
from hidet.ir.tools import infer_type

if a == 0:
Expand Down Expand Up @@ -497,6 +494,36 @@ def __complex__(self):
def array(self) -> np.ndarray:
return self.value

# This speciallisation of Expr._binary is done for speedup purposes only
# and fully equvivalent to Expr._binary by functionality
@staticmethod
def _binary(cls, a: Expr, b: Expr): # pylint: disable=bad-staticmethod-argument
def _binary_internal(cls, a: int, b: int, a_type, res_type):
value = operator_dict[cls](a, b)
if cls in [Equal, NotEqual, LessThan, LessEqual]:
return boolean.true if value else boolean.false
elif cls in [LeftShift, RightShift]:
return constant_int(value, a_type)
else:
return constant_int(value, res_type)

if (
isinstance(a, Constant)
and isinstance(b, Constant)
and isinstance(a.type, IntegerType)
and isinstance(b.type, IntegerType)
):
res_type = a.type if a.type is b.type else promote_type(a.type, b.type)
return _binary_internal(cls, a.value, b.value, a.type, res_type)
if isinstance(b, int) and isinstance(a.type, IntegerType):
res_type = int64 if (a.type is int64 or a.type is uint64) else int32
return _binary_internal(cls, a.value, b, a.type, res_type)
if isinstance(a, int) and isinstance(b.type, IntegerType):
res_type = int64 if (b.type is int64 or b.type is uint64) else int32
return _binary_internal(cls, a, b.value, int32, res_type)
# 2.5% cases fall here
return super(Constant, Constant)._binary(cls, a, b)


class IfThenElse(Expr):
"""
Expand Down Expand Up @@ -927,8 +954,6 @@ def is_constant(e: Union[Expr, PyScalar], *other: Union[Expr, PyScalar]) -> bool


def constant(value, const_type: Union[str, BaseType]) -> Constant:
from hidet.ir.dtypes import boolean

if const_type and isinstance(const_type, str):
const_type = data_type(const_type)

Expand Down Expand Up @@ -967,6 +992,16 @@ def constant(value, const_type: Union[str, BaseType]) -> Constant:
return Constant(value, const_type)


def constant_int(value: int, const_type: IntegerType) -> Constant:
if -128 <= value <= 128:
# pylint: disable=protected-access
if (value, const_type.name) not in Constant._constant_pool:
Constant._constant_pool[(value, const_type.name)] = Constant(value, const_type)
return Constant._constant_pool[(value, const_type.name)]
else:
return Constant(value, const_type)


def symbol_var(name: str, dtype='int32') -> SymbolVar:
dtype = data_type(dtype)
if name not in SymbolVar.name2symbol:
Expand Down

0 comments on commit eefc9d8

Please sign in to comment.