[Relay tests] AlterOpLayout - Temporary attr update (#4357)
anijain2305 authored and tqchen committed Nov 19, 2019
1 parent f1d6f33 commit 26eb405
Showing 11 changed files with 829 additions and 671 deletions.
6 changes: 6 additions & 0 deletions include/tvm/relay/op.h
@@ -258,6 +258,12 @@ class OpRegistry {
  inline OpRegistry& set_attr(const std::string& attr_name,  // NOLINT(*)
                              const ValueType& value, int plevel = 10);

  /*!
   * \brief Resets an attr of the registry.
   * \param attr_name The name of the attribute.
   */
  inline void reset_attr(const std::string& attr_name);

  // set the name of the op to be the same as registry
  inline OpRegistry& set_name() {  // NOLINT(*)
    if (get()->name.length() == 0) {
10 changes: 10 additions & 0 deletions python/tvm/relay/op/op.py
@@ -64,6 +64,16 @@ def set_attr(self, attr_name, value, plevel=10):
"""
_OpSetAttr(self, attr_name, value, plevel)

    def reset_attr(self, attr_name):
        """Reset an attribute of the operator.
        Parameters
        ----------
        attr_name : str
            The attribute name
        """
        _OpResetAttr(self, attr_name)


def get(op_name):
"""Get the Op for a given name
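For context, a minimal sketch of how the new Python-side reset_attr pairs with the existing register and get_attr helpers. The attribute name "fdouble" and the double helper are illustrative only, not attrs that TVM itself defines:

from tvm import relay

def double(x):
    return x * 2

# "fdouble" is an illustrative attribute name registered on the exp op.
relay.op.register("exp", "fdouble", double)
assert relay.op.get("exp").get_attr("fdouble")(3) == 6

# reset_attr clears the registration, after which get_attr returns None.
relay.op.get("exp").reset_attr("fdouble")
assert relay.op.get("exp").get_attr("fdouble") is None
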
1 change: 1 addition & 0 deletions python/tvm/relay/testing/__init__.py
@@ -37,6 +37,7 @@
from . import vgg
from . import densenet
from . import yolo_detection
from . import temp_op_attr

from .config import ctx_list
from .init import create_workload
63 changes: 63 additions & 0 deletions python/tvm/relay/testing/temp_op_attr.py
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Defines a TempOpAttr class that allows temporarily changing an attr of the
operator to allow unit testing. This is useful for AlterOpLayout and Legalize
tests."""

from tvm import relay

class TempOpAttr(object):
    """ Temporarily changes the attr of an op. """
    def __init__(self, op_name, attr_key, attr_value):
        """ Saves the required info for RAII pattern usage.

        Parameters
        ----------
        op_name : str
            The op name.
        attr_key : str
            The attribute name.
        attr_value : object
            The attribute value.

        Examples
        --------
        .. code-block:: python

            # Temporarily update FTVMAlterOpLayout to a user-defined packed function.
            # After the test is finished, the attr value will be set back to the original value.
            with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
                my_mod = relay.transform.AlterOpLayout()(my_mod)
        """
        self.op = relay.op.get(op_name)
        self.attr_key = attr_key
        self.attr_value = attr_value

    def __enter__(self):
        self.older_attr = self.op.get_attr(self.attr_key)
        self.op.reset_attr(self.attr_key)
        self.op.set_attr(self.attr_key, self.attr_value)
        return self

    def __exit__(self, ptype, value, trace):
        self.op.reset_attr(self.attr_key)
        if self.older_attr:
            self.op.set_attr(self.attr_key, self.older_attr)
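As a usage note, the sketch below shows how TempOpAttr is intended to be used in an AlterOpLayout test, following the pattern hinted at in the docstring above. The alter_conv2d body, the workload shapes, and the "OIHW16i" layout are illustrative assumptions, not part of this commit:

from tvm import relay
from tvm.relay.testing.temp_op_attr import TempOpAttr

def alter_conv2d(attrs, inputs, tinfos):
    # Illustrative alter function: rebuild the conv2d with a packed kernel
    # layout so the pass has a rewrite to perform.
    data, weight = inputs
    new_attrs = dict(attrs)
    new_attrs["kernel_layout"] = "OIHW16i"
    return relay.nn.conv2d(data, weight, **new_attrs)

# A small conv2d workload to run the pass over.
data = relay.var("data", shape=(1, 64, 56, 56))
weight = relay.var("weight", shape=(64, 64, 3, 3))
conv = relay.nn.conv2d(data, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
mod = relay.Module.from_expr(relay.Function([data, weight], conv))

# Inside the block, FTVMAlterOpLayout of nn.conv2d points at alter_conv2d;
# the previous registration (if any) is restored when the block exits.
with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
    seq = relay.transform.Sequential([relay.transform.InferType(),
                                      relay.transform.AlterOpLayout()])
    with relay.transform.PassContext(opt_level=3):
        mod = seq(mod)
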
28 changes: 27 additions & 1 deletion src/relay/ir/op.cc
@@ -95,6 +95,20 @@ const bool Op::HasGenericAttr(const std::string& key) {
  return true;
}

// Resets attr of the OpMap.
void OpRegistry::reset_attr(const std::string& key) {
  OpManager* mgr = OpManager::Global();
  std::lock_guard<std::mutex> lock(mgr->mutex);
  std::unique_ptr<GenericOpMap>& op_map = mgr->attr[key];
  if (op_map == nullptr) {
    return;
  }
  uint32_t index = op_->index_;
  if (op_map->data_.size() > index) {
    op_map->data_[index] = std::make_pair(TVMRetValue(), 0);
  }
}

void OpRegistry::UpdateAttr(const std::string& key,
                            TVMRetValue value,
                            int plevel) {
@@ -113,7 +127,10 @@ void OpRegistry::UpdateAttr(const std::string& key,
  CHECK(p.second != plevel)
      << "Attribute " << key << " of operator " << this->name
      << " is already registered with same plevel=" << plevel;
  if (p.second < plevel) {
  CHECK(value.type_code() != kNull)
      << "Registered packed_func is Null for " << key
      << " of operator " << this->name;
  if (p.second < plevel && value.type_code() != kNull) {
    op_map->data_[index] = std::make_pair(value, plevel);
  }
}
@@ -152,6 +169,15 @@ TVM_REGISTER_API("relay.op._OpSetAttr")
    reg.set_attr(attr_name, value, plevel);
  });

TVM_REGISTER_API("relay.op._OpResetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Op op = args[0];
    std::string attr_name = args[1];
    auto& reg =
        OpRegistry::Registry()->__REGISTER_OR_GET__(op->name);
    reg.reset_attr(attr_name);
  });

TVM_REGISTER_API("relay.op._Register")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    std::string op_name = args[0];
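The registry change above also motivates the RAII helper: UpdateAttr refuses a second registration at the same plevel and now rejects null packed functions, so TempOpAttr resets the slot before setting the new value. A minimal Python-level sketch of that interplay, with the attribute name "fdemo" and both helpers as illustrative assumptions:

from tvm import relay

def impl_a(x):
    return x + 1

def impl_b(x):
    return x + 2

# "fdemo" is an illustrative attribute name registered on the tanh op.
relay.op.register("tanh", "fdemo", impl_a, level=10)

op = relay.op.get("tanh")
# Registering again at plevel 10 would trip the CHECK in UpdateAttr,
# so clear the slot first; then the same plevel can be reused.
op.reset_attr("fdemo")
op.set_attr("fdemo", impl_b, 10)
assert op.get_attr("fdemo")(1) == 3
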
47 changes: 47 additions & 0 deletions tests/python/relay/test_ir_op.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from tvm import relay
from tvm.relay.testing.temp_op_attr import TempOpAttr

def test_op_attr():
    log_op = relay.op.get("log")
@@ -27,6 +28,50 @@ def test(x):
    assert log_op.get_attr("ftest") is None
    assert relay.op.get("exp").get_attr("ftest")(1) == 2

def test_op_reset_attr():
    """ Tests reset_attr functionality. """
    def add1(x):
        return x + 1

    def add2(x):
        return x + 2

    # Register fadd1 and fadd2 attributes.
    relay.op.register("exp", "fadd1", add1)
    relay.op.register("log", "fadd1", add1)
    relay.op.register("log", "fadd2", add2)

    # Reset log fadd1 attr.
    log_op = relay.op.get("log")
    log_op.reset_attr("fadd1")

    # Check that the fadd1 attr is reset.
    assert log_op.get_attr("fadd1") is None

    # Check that the fadd1 attr of other ops is intact.
    assert relay.op.get("exp").get_attr("fadd1")(1) == 2

    # Check that other attrs of the log op are intact.
    assert relay.op.get("log").get_attr("fadd2")(1) == 3

def test_op_temp_attr():
    """ Tests TempOpAttr functionality. """
    def add1(x):
        return x + 1

    def add2(x):
        return x + 2

    # Set the original attr value to add1.
    relay.op.register("sqrt", "ftest", add1)

    with TempOpAttr("sqrt", "ftest", add2):
        # Check that the attr value is updated to add2.
        assert relay.op.get("sqrt").get_attr("ftest")(1) == 3

    # Check that the attr value is recovered to add1.
    assert relay.op.get("sqrt").get_attr("ftest")(1) == 2

def test_op_level1():
    x = relay.Var("x")

@@ -47,5 +92,7 @@ def test_op_level3():

if __name__ == "__main__":
    test_op_attr()
    test_op_reset_attr()
    test_op_temp_attr()
    test_op_level1()
    test_op_level3()
