Commit

[Unity][MSC][M4.2][Step2] Enable plugin with manager, test plugins in compile pipeline (#16581)

enable plugin with manager
Archermmt committed Feb 20, 2024
1 parent 4600002 commit 2066ce9
Showing 3 changed files with 96 additions and 4 deletions.
9 changes: 7 additions & 2 deletions python/tvm/contrib/msc/framework/torch/frontend/translate.py
@@ -70,6 +70,7 @@ def from_torch(
build_config: Optional[Dict[str, str]] = None,
opt_config: Optional[Dict[str, str]] = None,
as_msc: bool = True,
custom_convert_map: dict = None,
) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]:
"""Change torch nn.Module to MSCGraph.
@@ -91,6 +92,8 @@ def from_torch(
The config for optimize the relay before translate.
as_msc: bool
Set to True to return MSCGraph, otherwise the relax mod
custom_convert_map: dict
The custom op conversion map for plugins
Returns
-------
@@ -103,7 +106,7 @@ def from_torch(
if via_relax:
graph_model, params = torch.fx.symbolic_trace(model), None
with torch.no_grad():
relax_mod = from_fx(graph_model, input_info)
relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)
else:
datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
torch_datas = [torch.from_numpy(i) for i in datas]
@@ -116,7 +119,9 @@ def from_torch(
shape_list = list(zip(input_names, input_info))
else:
shape_list = [("input" + str(idx), i_info) for idx, i_info in enumerate(input_info)]
relay_mod, params = tvm.relay.frontend.from_pytorch(scripted_model, shape_list)
relay_mod, params = tvm.relay.frontend.from_pytorch(
scripted_model, shape_list, custom_convert_map=custom_convert_map
)
relax_mod = relay_to_relax(relay_mod, params, trans_config, build_config, opt_config)
if not as_msc:
return relax_mod, params
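For orientation, here is a minimal sketch of how a caller might pass a plugin's converters through this entry point; the MyModel module and my_op_converter below are hypothetical placeholders, not part of this commit:

    from tvm.contrib.msc.framework.torch.frontend import translate

    # Hypothetical nn.Module whose forward calls a plugin op traced as "my_op".
    model = MyModel()
    # Same format as the FX importer's convert_map: traced name -> converter.
    convert_map = {"my_op": my_op_converter}
    graph, weights = translate.from_torch(
        model,
        input_info=[([1, 3, 224, 224], "float32")],
        via_relax=True,  # takes the from_fx branch shown above
        custom_convert_map=convert_map,
    )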
33 changes: 31 additions & 2 deletions python/tvm/relax/frontend/torch/fx_translator.py
@@ -1459,17 +1459,34 @@ def create_convert_map(self):
"scaled_dot_product_attention": self._scaled_dot_product_attention,
}

def update_convert_map(self, custom_convert_map: dict):
"""Update self.convert_map with custom convert map
Parameters
----------
custom_convert_map : Dictionary of str to Relax op
A custom op conversion map in the same format as self.convert_map
"""

self.convert_map.update(custom_convert_map)

def from_fx(
self,
model,
input_info: List[Tuple[Tuple[int], str]],
keep_params_as_input: bool,
unwrap_unit_return_tuple: bool,
no_bind_return_tuple: bool,
custom_convert_map: dict = None,
) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program."""
from torch import fx

if custom_convert_map:
custom_ops = set(custom_convert_map.keys())
self.update_convert_map(custom_convert_map)
else:
custom_ops = set()
self.named_modules = dict(model.named_modules())

graph: fx.Graph = model.graph
@@ -1548,7 +1565,10 @@ def from_fx(
assert (
func_name in self.convert_map
), f"Unsupported function type {func_name}"
self.env[node] = self.convert_map[func_name](node)
if func_name in custom_ops:
self.env[node] = self.convert_map[func_name](node, self)
else:
self.env[node] = self.convert_map[func_name](node)
elif node.op == "call_method":
assert (
node.target in self.convert_map
@@ -1572,6 +1592,7 @@
keep_params_as_input: bool = False,
unwrap_unit_return_tuple: bool = False,
no_bind_return_tuple: bool = False,
custom_convert_map: dict = None,
) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program
@@ -1594,6 +1615,9 @@
A boolean flag indicating whether to bind the return tuple as a relax var.
If the flag is true and the return value is a tuple, it will not bind it to a var.
custom_convert_map : Dictionary of str to Relax op
A custom op conversion map in the same format as TorchFXImporter.convert_map
Returns
-------
output : tvm.IRModule
@@ -1662,5 +1686,10 @@ def forward(self, input):
check the placeholder rows in the beginning of the tabular.
"""
return TorchFXImporter().from_fx(
model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple
model,
input_info,
keep_params_as_input,
unwrap_unit_return_tuple,
no_bind_return_tuple,
custom_convert_map=custom_convert_map,
)
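Note the calling convention established above: a converter registered through custom_convert_map is invoked with the FX node plus the importer instance, so it can reach the importer's environment and block builder. A minimal sketch of such a converter, assuming a hypothetical plugin op traced as "my_relu":

    from tvm import relax

    def my_relu_converter(node, importer):
        # Fetch the Relax expression already emitted for the node's input.
        data = importer.env[node.args[0]]
        # Emit the replacement op through the importer's block builder.
        return importer.block_builder.emit(relax.op.nn.relu(data))

    custom_convert_map = {"my_relu": my_relu_converter}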
58 changes: 58 additions & 0 deletions tests/python/contrib/test_msc/test_plugin.py
@@ -26,6 +26,7 @@
from tvm import relax
from tvm.relax.transform import BindParams
from tvm.script import relax as R
from tvm.contrib.msc.pipeline import MSCManager
from tvm.contrib.msc.plugin import build_plugins
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.core import utils as msc_utils
@@ -287,6 +288,39 @@ def _test_torch_plugin(manager):
assert outputs.min() >= 0 and outputs.max() <= 0.5


def _test_with_manager(plugins, compile_type, expected_info):
"""Test the plugin with manager"""

path = "test_plugin_" + compile_type
model = _get_torch_model(plugins[MSCFramework.TORCH])
if torch.cuda.is_available():
model = model.to(torch.device("cuda:0"))
config = {
"workspace": msc_utils.msc_dir(path),
"model_type": MSCFramework.TORCH,
"verbose": "critical",
"inputs": [["input_0", [1, 3, 224, 224], "float32"]],
"outputs": ["output"],
"dataset": {"prepare": {"loader": "from_random", "max_iter": 5}},
"prepare": {"profile": {"benchmark": {"repeat": 10}}},
"baseline": {
"profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark": {"repeat": 10}},
},
"compile": {
"run_type": compile_type,
"profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark": {"repeat": 10}},
},
}
manager = MSCManager(model, config, plugins=plugins)
report = manager.run_pipe()
model_info = manager.runner.model_info
manager.destory()
assert report["success"], "Failed to run pipe for torch -> {}".format(compile_type)
assert msc_utils.dict_equal(
model_info, expected_info
), "Model info {} mismatch with expected {}".format(model_info, expected_info)


def test_plugin():
"""Test the plugins"""

@@ -302,6 +336,30 @@ def test_plugin():
_test_tvm_plugin(managers[MSCFramework.TVM], "cuda")
_test_torch_plugin(managers[MSCFramework.TORCH])

# test the plugin with manager
model_info = {
"inputs": [
{"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "output", "shape": [1, 6, 218, 218], "dtype": "float32", "layout": "NCHW"}
],
"nodes": {"total": 4, "input": 1, "msc.conv2d_bias": 1, "MyRelu": 1, "nn.max_pool2d": 1},
}
_test_with_manager(managers, MSCFramework.TORCH, model_info)
_test_with_manager(managers, MSCFramework.TVM, model_info)
if tvm.get_global_func("relax.ext.tensorrt", True) is not None:
byoc_info = {
"inputs": [
{"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "output", "shape": [1, 6, 218, 218], "dtype": "float32", "layout": ""}
],
"nodes": {"total": 2, "input": 1, "msc_tensorrt": 1},
}
_test_with_manager(managers, MSCFramework.TENSORRT, byoc_info)

plugin_root.destory()
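Distilled from the test above, the manager-level flow for plugin-enabled compilation reduces to a few calls (model, config and plugins as constructed earlier in this file):

    from tvm.contrib.msc.pipeline import MSCManager

    manager = MSCManager(model, config, plugins=plugins)
    report = manager.run_pipe()
    assert report["success"]
    model_info = manager.runner.model_info
    manager.destory()  # note: spelled this way in the current MSC API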


