Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions python/tvm/topi/adreno/conv2d_winograd_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def conv2d_winograd_comp(

convert_from4d = False
if len(data.shape) == 4:
convert_from4d = True
if layout == "NCHW":
N, DCI, H, W = get_const_tuple(data.shape)
else:
Expand Down Expand Up @@ -120,7 +121,6 @@ def conv2d_winograd_comp(
data = tvm.te.placeholder(dshape, data.dtype, name="data_placeholder")
kernel = tvm.te.placeholder(kshape, kernel.dtype, name="kernel_placeholder")
else:
convert_from4d = True
data = pack_input(
data, layout, N, in_channel_chunks, in_channel_block, in_channel_tail, H, W
)
Expand Down Expand Up @@ -220,9 +220,9 @@ def conv2d_winograd_comp(
idxdiv = tvm.tir.indexdiv
idxmod = tvm.tir.indexmod
if layout == "NCHW":
N, CI, H, W, CB = get_const_tuple(data.shape)
N, CI, _, _, CB = get_const_tuple(data.shape)
else:
N, H, W, CI, CB = get_const_tuple(data.shape)
N, _, _, CI, CB = get_const_tuple(data.shape)

# pack input tile
if layout == "NCHW":
Expand Down Expand Up @@ -494,16 +494,18 @@ def schedule_conv2d_winograd(cfg, s, output, pre_computed):
s[OL].set_scope("local")
output = s.outputs[0]

m = alpha - 3 + 1
if len(s[output].op.axis) == 4:
n, co, h, w = s[output].op.axis
cb = None
else:
n, co, h, w, _ = s[output].op.axis
ho, wo, hi, wi = s[output].tile(h, w, m, m)
n, co, h, w, cb = s[output].op.axis
inverse_scope, n = s[output].split(n, nparts=1)

fused = s[output].fuse(n, co, ho, wo)
fused = s[output].fuse(n, co, h, w)
bb, tt = s[output].split(fused, 128)
if cb is not None:
s[output].reorder(bb, tt, cb)
s[output].vectorize(cb)

s[output].bind(bb, te.thread_axis("blockIdx.x"))
s[output].bind(tt, te.thread_axis("threadIdx.x"))
Expand Down
60 changes: 59 additions & 1 deletion tests/python/relay/test_conv2d_nchw_texture.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np
from tvm import relay
from tvm.relay import testing
from tvm.contrib import utils
from utils.adreno_utils import gpu_preprocess, build_run_compare


Expand Down Expand Up @@ -432,6 +433,63 @@ def test_conv2d_vgg16_winograd_4d():
"bias": tvm.nd.array(bias_data),
}

graph = build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
temp = utils.tempdir()
stat_file = temp.relpath("stat.log")
with open(stat_file, "w") as f:
f.write(
'{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno -max_num_threads=256", "conv2d_nchw_winograd_acc32.image2d", [["TENSOR", [1, 512, 28, 28], "float16"], ["TENSOR", [512, 512, 3, 3], "float16"], [1, 1], [1, 1, 1, 1], [1, 1], "float16"], {}], "config": {"index": 1591, "code_hash": null, "entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}, "result": [[0.0037244], 0, 7.06374192237854, 1653898629.7427933], "version": 0.2, "tvm_version": "0.8.dev0"}\n'
)
graph = build_run_compare(
mod, params1, {"data": input_shape}, dtype, target, stat_file=stat_file
)
matches = re.findall("winograd", graph)
assert len(matches) > 0


@tvm.testing.requires_opencl
def test_conv2d_winograd_conv():
    """Two stacked 3x3 NCHW conv2d ops on the Adreno OpenCL target.

    Builds once without tuning statistics and once with a pre-baked
    autotvm log that pins the first conv to the
    ``conv2d_nchw_winograd_acc32.image2d`` schedule, then asserts the
    compiled graph actually contains winograd kernels.
    """
    target = "opencl --device=adreno"
    dtype = "float16"

    # conv1: (1, 4, 3, 3) x (8, 4, 3, 3) -> 8 channels, 'same' padding.
    input_shape = (1, 4, 3, 3)
    A = relay.var("data", shape=input_shape, dtype=dtype)
    filter_shape3 = (8, 4, 3, 3)
    B3 = relay.var("weight3", shape=filter_shape3, dtype=dtype)
    D = relay.nn.conv2d(
        A, B3, padding=[1, 1, 1, 1], channels=8, kernel_size=[3, 3], out_dtype=dtype
    )

    # conv2: consumes conv1's output, keeps 8 channels.
    filter_shape4 = (8, 8, 3, 3)
    B4 = relay.var("weight4", shape=filter_shape4, dtype=dtype)
    D = relay.nn.conv2d(
        D, B4, padding=[1, 1, 1, 1], channels=8, kernel_size=[3, 3], out_dtype=dtype
    )
    mod = relay.Function([A, B3, B4], D)

    # Deterministic weights; the graph has no bias ops, so only the two
    # weight tensors are bound as params.
    np.random.seed(1)
    initializer = relay.testing.init.Xavier()
    filter_data3 = np.zeros(filter_shape3).astype(dtype)
    filter_data4 = np.zeros(filter_shape4).astype(dtype)
    initializer("weight", filter_data3)
    initializer("weight", filter_data4)
    params1 = {
        "weight3": tvm.nd.array(filter_data3),
        "weight4": tvm.nd.array(filter_data4),
    }

    # Write a one-entry autotvm log forcing the winograd image2d schedule
    # for the first conv; build_run_compare applies it via
    # autotvm.apply_history_best when stat_file is given.
    temp = utils.tempdir()
    stat_file = temp.relpath("stat.log")
    with open(stat_file, "w") as f:
        f.write(
            '{"input": ["opencl -keys=adreno,opencl,gpu -device=adreno -max_num_threads=256", "conv2d_nchw_winograd_acc32.image2d", [["TENSOR", [1, 4, 3, 3], "float16"], ["TENSOR", [8, 4, 3, 3], "float16"], [1, 1], [1, 1, 1, 1], [1, 1], "float16"], {}], "config": {"index": 1591, "code_hash": null, "entity": [["auto_unroll_max_step", "ot", 4], ["tile_y", "sp", [-1, 1, 32]], ["tile_x", "sp", [-1, 4, 2]], ["tile_rc", "sp", [-1, 8]]]}, "result": [[0.0037244], 0, 7.06374192237854, 1653898629.7427933], "version": 0.2, "tvm_version": "0.8.dev0"}\n'
        )
    graph = build_run_compare(
        mod, params1, {"data": input_shape}, dtype, target, stat_file=stat_file
    )
    # The tuned build must have selected winograd kernels.
    matches = re.findall("winograd", graph)
    assert len(matches) > 0
25 changes: 20 additions & 5 deletions tests/python/relay/utils/adreno_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tvm
import numpy as np
from tvm import relay
from tvm import autotvm
from tvm.relay import testing
from tvm.relay.transform import recast
from tvm.contrib import graph_runtime
Expand All @@ -45,7 +46,13 @@ def get_cpu_reference(mod, params1, input_shape, inputs):

# build module run with opencl and cpu, compare results
def build_run_compare(
tvm_mod, params1, input_shape, dtype="float32", target="llvm", gpu_preprocess=None
tvm_mod,
params1,
input_shape,
dtype="float32",
target="llvm",
gpu_preprocess=None,
stat_file=None,
):

if "TVM_TRACKER_HOST" in os.environ and "TVM_TRACKER_PORT" in os.environ:
Expand All @@ -63,10 +70,18 @@ def build_run_compare(
else:
tvm_mod_nchwc = tvm_mod

with relay.build_config(opt_level=3):
graph, lib, params = relay.build(
tvm_mod_nchwc, target_host=target_host, target=target, params=params1
)
if stat_file is not None:
with autotvm.apply_history_best(stat_file):
with tvm.transform.PassContext(opt_level=3):
graph, lib, params = relay.build(
tvm_mod_nchwc, target_host=target_host, target=target, params=params1
)
else:
with tvm.transform.PassContext(opt_level=3):
graph, lib, params = relay.build(
tvm_mod_nchwc, target_host=target_host, target=target, params=params1
)

if run_on_host:
ctx = tvm.opencl()
m = graph_runtime.create(graph, lib, ctx)
Expand Down