@ajtulloch @tqchen
Hi all. This is a very interesting problem.
I am using the TVM version that includes the fix for issue #2207. I modified this case slightly by adding cache_read/cache_write to the schedule:
```python
from topi.util import get_const_tuple
import tvm


def func(Elements, Lengths):
    def f(n, d):
        # Data-dependent reduction extent: row n sums the first Lengths[n] rows.
        rg = tvm.reduce_axis((0, Lengths[n]))
        return tvm.sum(Elements[rg, d], axis=rg)

    (N,) = get_const_tuple(Lengths.shape)
    (_, D) = get_const_tuple(Elements.shape)
    return tvm.compute((N, D), f, name="Y")


def run(N, I, D):
    Elements = tvm.placeholder(shape=(I, D), dtype="float32", name="Elements")
    Lengths = tvm.placeholder(shape=(N,), dtype="int32", name="Lengths")
    Y = func(Elements, Lengths)
    s = tvm.create_schedule([Y.op])
    # Cache stages added to the original example.
    Elements_local = s.cache_read(Elements, "local", [Y])
    Lengths_local = s.cache_read(Lengths, "local", [Y])
    Y_local = s.cache_write(Y, "local")
    print(tvm.lower(s, [Elements, Lengths, Y], simple_mode=True))
    #print(tvm.save_json(Y))
    #f = tvm.build(s, [Elements, Lengths, Y], target="llvm")


run(N=10, I=10, D=128)
```
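For reference, here is a plain NumPy sketch of what `Y` is meant to contain: `Y[n, d]` is the sum of `Elements[i, d]` over `i` in `[0, Lengths[n])`. It could serve as a ground truth for checking the `tvm.build` output once lowering works. The function name `reference_y` and the random inputs are hypothetical.

```python
import numpy as np


def reference_y(elements, lengths):
    """Y[n, d] = sum of elements[i, d] for i in [0, lengths[n])."""
    y = np.zeros((lengths.shape[0], elements.shape[1]), dtype=elements.dtype)
    for n in range(lengths.shape[0]):
        # Variable-length reduction over the leading rows of `elements`,
        # mirroring rg = tvm.reduce_axis((0, Lengths[n])).
        y[n] = elements[: lengths[n]].sum(axis=0)
    return y


# Same shapes as run(N=10, I=10, D=128); the test data is hypothetical.
rng = np.random.default_rng(0)
elements = rng.random((10, 128), dtype=np.float32)
lengths = rng.integers(0, 11, size=10).astype(np.int32)
y = reference_y(elements, lengths)
print(y.shape)  # (10, 128)
```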
Unfortunately, tvm.lower() fails with
```
tvm._ffi.base.TVMError: [08:42:49] /home/dylan/tvm/src/schedule/schedule_dataflow_rewrite.cc:166: Check failed: iv->iter_type == kDataPar (2 vs. 0) Can only relayout with in data parallel dimensions
```
I don't know the root cause of the problem. My guess is that it is caused by the data-dependent bound of the reduction axis:
```python
rg = tvm.reduce_axis((0, Lengths[n]))
```
If I use a constant bound instead, the error goes away:
```python
rg = tvm.reduce_axis((0, 4))
```
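Since a constant bound works, one possible rewrite (an untested assumption on my part) is to reduce over the constant extent `(0, I)` and mask out rows with `i >= Lengths[n]`. My understanding is that TVM reducers such as `tvm.sum` accept a `where` predicate, so the TVM side might read `tvm.sum(Elements[rg, d], axis=rg, where=rg < Lengths[n])`. I have not checked whether this avoids the `kDataPar` check. The NumPy sketch below only shows that the two formulations agree numerically, assuming `Lengths[n] <= I`. Both function names are hypothetical.

```python
import numpy as np


def variable_length_y(elements, lengths):
    # Direct form: row n reduces over i in [0, lengths[n]).
    out = np.zeros((lengths.shape[0], elements.shape[1]), dtype=elements.dtype)
    for n, length in enumerate(lengths):
        out[n] = elements[:length].sum(axis=0)
    return out


def masked_fixed_extent_y(elements, lengths):
    # Rewritten form: every row reduces over the constant extent [0, I),
    # and entries with i >= lengths[n] are masked out (assumes lengths[n] <= I).
    num_i = elements.shape[0]
    mask = np.arange(num_i)[None, :] < lengths[:, None]  # shape (N, I)
    # (N, I) @ (I, D) -> (N, D): each output row sums only the unmasked rows.
    return mask.astype(elements.dtype) @ elements


rng = np.random.default_rng(0)
elements = rng.random((10, 128), dtype=np.float32)
lengths = rng.integers(0, 11, size=10).astype(np.int32)
print(np.allclose(variable_length_y(elements, lengths),
                  masked_fixed_extent_y(elements, lengths), atol=1e-5))
```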