Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VM][Adreno] Fix using buffers for weights in VM #15671

Merged
merged 2 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/relay/transforms/annotate_texture_storage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,11 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
for (const auto& ttype : FlattenTupleType(fn->params[i]->checked_type())) {
std::string scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(call)));
if (expr_attrib.as<Conv2DAttrs>() || expr_attrib.as<Conv2DWinogradAttrs>()) {
String kernel_layout = expr_attrib.as<Conv2DAttrs>()
? expr_attrib.as<Conv2DAttrs>()->kernel_layout
: expr_attrib.as<Conv2DWinogradAttrs>()->kernel_layout;
if ((i == weights_pos) && !ttype->dtype.is_float16() &&
CanUseBuffers(call->args[i], ttype->shape, fn->attrs)) {
CanUseBuffers(call->args[i], ttype->shape, kernel_layout)) {
buffers_params.insert(fn->params[i]);
buffers_args.insert(call->args[i]);
scope = "global";
Expand Down Expand Up @@ -426,10 +429,9 @@ class StorageInfo : private transform::DeviceAwareExprVisitor {
}

bool CanUseBuffers(const Expr param, const Array<PrimExpr> shape,
const tvm::DictAttrs param_attrs) const {
const String kernel_layout) const {
bool use_buffer = false;
if (param.as<ConstantNode>() && shape.size() == 5) {
auto kernel_layout = param_attrs.GetAttr<String>("kernel_layout");
if (kernel_layout == "HWOI4o" || kernel_layout == "HWIO4o") {
int a0 = shape[0].as<IntImmNode>()->value;
int a1 = shape[1].as<IntImmNode>()->value;
Expand Down
77 changes: 66 additions & 11 deletions tests/python/relay/opencl_texture/test_conv2d_nchw_texture.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,6 @@ def test_residual_block(remote, target, executor_type, dtype):
{"data": input_shape},
{"data": dtype},
target,
static_memory_scope,
)


Expand Down Expand Up @@ -790,11 +789,12 @@ def test_concat(remote, target, executor_type, dtype):

static_memory_scope = [
"",
"global.texture",
"global",
"global.texture-weight",
"global.texture-weight",
"global",
"global.texture-weight",
"global.texture-nhwc",
"global",
"global.texture-weight",
"",
"",
Expand All @@ -803,8 +803,6 @@ def test_concat(remote, target, executor_type, dtype):
"",
]

static_memory_scope = []

if executor_type == "ge":
build_run_compare(
remote,
Expand All @@ -823,7 +821,6 @@ def test_concat(remote, target, executor_type, dtype):
{"data": input_shape},
{"data": dtype},
target,
static_memory_scope,
)


Expand Down Expand Up @@ -968,7 +965,6 @@ def test_pooling_branching_texture_params(remote, target, executor_type, dtype):
{"data": input_shape},
{"data": dtype},
target,
static_memory_scope,
)


Expand Down Expand Up @@ -1111,7 +1107,6 @@ def test_branching_texture_params(remote, target, executor_type, dtype):
{"data": input_shape},
{"data": dtype},
target,
static_memory_scope,
)


Expand Down Expand Up @@ -1212,7 +1207,6 @@ def test_conv2d_different_lowering_same_op(remote, target, executor_type, dtype)
{"data": input_shape},
{"data": dtype},
target,
static_memory_scope,
)


Expand Down Expand Up @@ -1380,7 +1374,6 @@ def test_injective_nwo_inputs1(remote, target, executor_type, dtype):
{"data": input_shape},
{"data": dtype},
target,
static_memory_scope,
)


Expand Down Expand Up @@ -1495,7 +1488,6 @@ def test_injective_nwo_inputs2(remote, target, executor_type, dtype):
{"data": input_shape},
{"data": dtype},
target,
static_memory_scope,
)


Expand Down Expand Up @@ -1534,5 +1526,68 @@ def test_conv2d_to_3_channels(remote, target, executor_type, dtype):
)


@tvm.testing.requires_opencl
@tvm.testing.parametrize_targets("opencl -device=adreno")
def test_conv2d_weight_on_buffers(remote, target, executor_type, dtype):
target = "opencl -device=adreno"
input_shape = (1, 64, 75, 75)
filter_shape = (64, 64, 3, 3)
bias_shape = (64,)
A = relay.var("data", shape=input_shape, dtype=dtype)
W = relay.var("weight", shape=filter_shape, dtype=dtype)
BS = relay.var("bias", shape=bias_shape, dtype=dtype)
conv = relay.nn.conv2d(A, W, padding=[1, 1, 1, 1], channels=64, kernel_size=(3, 3))
conv = relay.nn.bias_add(conv, BS)
conv = relay.op.nn.relu(conv)

mod = relay.Function([A, W, BS], conv)
np.random.seed(0)
initializer = relay.testing.init.Xavier()
filter_data = np.zeros(filter_shape).astype(dtype)
bias_data = np.zeros(bias_shape).astype(dtype)
initializer("weight", filter_data)
initializer("bias", bias_data)
params1 = {
"weight": tvm.nd.array(filter_data),
"bias": tvm.nd.array(bias_data),
}

if executor_type == "ge":
static_memory_scope = [
"",
"global.texture",
"global",
"global.texture-weight",
"",
"",
]
build_run_compare(
remote,
mod,
params1,
{"data": input_shape},
{"data": dtype},
target,
static_memory_scope,
)
else:
static_memory_scope = """
VM VirtualDevice[0]: device type 1, id 0 and mem_scope
VM VirtualDevice[1]: device type 4, id 0 and mem_scope
VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture
VM VirtualDevice[3]: device type 4, id 0 and mem_scope global
VM VirtualDevice[4]: device type 4, id 0 and mem_scope global.texture-weight
"""
build_run_compare_vm(
remote,
mod,
params1,
{"data": input_shape},
{"data": dtype},
target,
static_memory_scope,
)


if __name__ == "__main__":
tvm.testing.main()
18 changes: 5 additions & 13 deletions tests/python/relay/opencl_texture/utils/adreno_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,19 +161,11 @@ def build_run_compare_vm(
tvm_mod_nchwc, target=target, target_host=target_host, params=params1
)

# TODO(echuraev): enable scope checking
## verification that storage_scope has expected textures scopes
# graph_json = json.loads(graph)
# if "storage_scope" in graph_json["attrs"]:
# assert (
# len(static_mem_scopes) == len(graph_json["attrs"]["storage_scope"][1])
# or len(static_mem_scopes) == 0
# )
# else:
# assert len(static_mem_scopes) == 0

# for i in range(0, len(static_mem_scopes)):
# assert static_mem_scopes[i] == graph_json["attrs"]["storage_scope"][1][i]
if len(static_mem_scopes) > 0:
mem_scopes_lines = static_mem_scopes.strip().split("\n")
vm_lines = vmc._get_virtual_devices().strip().split("\n")
for i in range(0, len(mem_scopes_lines)):
assert mem_scopes_lines[i].strip() == vm_lines[i].strip()

if remote is None:
dev = tvm.opencl()
Expand Down