NOTE(review): this patch was re-expanded from a whitespace-collapsed copy.
Indentation and the positions of blank context lines are inferred from the
hunk line counts -- confirm against pygpu/reduction.py before applying.
_ceil_log2 differs from the recorded commit (exact integer path added), so
the post-image blob id on the "index" line no longer matches.

diff --git a/pygpu/reduction.py b/pygpu/reduction.py
index 6bc7c9c0df..812af8b6c4 100644
--- a/pygpu/reduction.py
+++ b/pygpu/reduction.py
@@ -8,6 +8,25 @@ from .tools import ArrayArg, check_args, prod, lfu_cache
 from .elemwise import parse_c_args, massage_op
 
 
+def _ceil_log2(x):
+    """Return the smallest integer ``k`` such that ``2**k >= x``.
+
+    Returns 0 when ``x`` is 0.  Integral inputs are handled with exact
+    integer arithmetic: ``math.log(x, 2)`` can land a hair above the
+    true value for exact powers of two (e.g. ``2**29``), and ``ceil``
+    would then round up to the next exponent.
+
+    Raises ``ValueError`` for negative ``x``, as ``math.log`` does.
+    """
+    if x == 0:
+        return 0
+    if x < 0:
+        raise ValueError("math domain error")
+    n = int(x)
+    if n != x:
+        # Non-integral input: keep the floating-point path.
+        return int(math.ceil(math.log(x, 2)))
+    return (n - 1).bit_length()
 
 basic_kernel = Template("""
 ${preamble}
@@ -172,11 +191,7 @@ def __init__(self, context, dtype_out, neutral, reduce_expr, redux,
 
     def _find_kernel_ls(self, tmpl, max_ls, *tmpl_args):
         local_size = min(self.init_local_size, max_ls)
-        # nearest power of 2 (going up)
-        if local_size != 0:
-            count_lim = int(math.ceil(math.log(local_size, 2)))
-        else:
-            count_lim = 0
+        count_lim = _ceil_log2(local_size)
         local_size = 2**count_lim
         loop_count = 0
         while loop_count <= count_lim:
@@ -248,7 +263,7 @@ def __call__(self, *args, **kwargs):
         if self.init_local_size < n:
             k, _, _, ls = self._get_basic_kernel(self.init_local_size, nd)
         else:
-            k, _, _, ls = self._get_basic_kernel(n, nd)
+            k, _, _, ls = self._get_basic_kernel(2**_ceil_log2(n), nd)
 
         kargs = [n, out]
         kargs.extend(dims)