Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][MSC][M4.2][Step2] Enable plugin with manager, test plugins in compile pipeline #16581

Merged
merged 1 commit into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/tvm/contrib/msc/framework/torch/frontend/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def from_torch(
build_config: Optional[Dict[str, str]] = None,
opt_config: Optional[Dict[str, str]] = None,
as_msc: bool = True,
custom_convert_map: dict = None,
) -> Tuple[Union[MSCGraph, tvm.IRModule], Dict[str, tvm.nd.array]]:
"""Change torch nn.Module to MSCGraph.

Expand All @@ -91,6 +92,8 @@ def from_torch(
The config for optimize the relay before translate.
as_msc: bool
Set to to return msc graph, otherwise relax mod
custom_convert_map: dict
The convert map for plugin

Returns
-------
Expand All @@ -103,7 +106,7 @@ def from_torch(
if via_relax:
graph_model, params = torch.fx.symbolic_trace(model), None
with torch.no_grad():
relax_mod = from_fx(graph_model, input_info)
relax_mod = from_fx(graph_model, input_info, custom_convert_map=custom_convert_map)
else:
datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]
torch_datas = [torch.from_numpy(i) for i in datas]
Expand All @@ -116,7 +119,9 @@ def from_torch(
shape_list = list(zip(input_names, input_info))
else:
shape_list = [("input" + str(idx), i_info) for idx, i_info in enumerate(input_info)]
relay_mod, params = tvm.relay.frontend.from_pytorch(scripted_model, shape_list)
relay_mod, params = tvm.relay.frontend.from_pytorch(
scripted_model, shape_list, custom_convert_map=custom_convert_map
)
relax_mod = relay_to_relax(relay_mod, params, trans_config, build_config, opt_config)
if not as_msc:
return relax_mod, params
Expand Down
33 changes: 31 additions & 2 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1459,17 +1459,34 @@ def create_convert_map(self):
"scaled_dot_product_attention": self._scaled_dot_product_attention,
}

def update_convert_map(self, custom_convert_map: dict):
"""Update self.convert_map with custom convert map

Parameters
----------
custom_convert_map : Dictionary of str to Relax op
A custom op conversion map in the same format as self.convert_map
"""

self.convert_map.update(custom_convert_map)

def from_fx(
self,
model,
input_info: List[Tuple[Tuple[int], str]],
keep_params_as_input: bool,
unwrap_unit_return_tuple: bool,
no_bind_return_tuple: bool,
custom_convert_map: dict = None,
) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program."""
from torch import fx

if custom_convert_map:
custom_ops = set(custom_convert_map.keys())
self.update_convert_map(custom_convert_map)
else:
custom_ops = set()
self.named_modules = dict(model.named_modules())

graph: fx.Graph = model.graph
Expand Down Expand Up @@ -1548,7 +1565,10 @@ def from_fx(
assert (
func_name in self.convert_map
), f"Unsupported function type {func_name}"
self.env[node] = self.convert_map[func_name](node)
if func_name in custom_ops:
self.env[node] = self.convert_map[func_name](node, self)
else:
self.env[node] = self.convert_map[func_name](node)
elif node.op == "call_method":
assert (
node.target in self.convert_map
Expand All @@ -1572,6 +1592,7 @@ def from_fx(
keep_params_as_input: bool = False,
unwrap_unit_return_tuple: bool = False,
no_bind_return_tuple: bool = False,
custom_convert_map: dict = None,
) -> tvm.IRModule:
"""Convert a PyTorch FX GraphModule to a Relax program

Expand All @@ -1594,6 +1615,9 @@ def from_fx(
A boolean flag indicating whether to bind the return tuple as a relax var.
If the flag is true and the return value is a tuple, it will not bind it to a var.

custom_convert_map : Dictionary of str to Relax op
A custom op conversion map in the same format as TorchFXImporter.convert_map

Returns
-------
output : tvm.IRModule
Expand Down Expand Up @@ -1662,5 +1686,10 @@ def forward(self, input):
check the placeholder rows in the beginning of the tabular.
"""
return TorchFXImporter().from_fx(
model, input_info, keep_params_as_input, unwrap_unit_return_tuple, no_bind_return_tuple
model,
input_info,
keep_params_as_input,
unwrap_unit_return_tuple,
no_bind_return_tuple,
custom_convert_map=custom_convert_map,
)
58 changes: 58 additions & 0 deletions tests/python/contrib/test_msc/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tvm import relax
from tvm.relax.transform import BindParams
from tvm.script import relax as R
from tvm.contrib.msc.pipeline import MSCManager
from tvm.contrib.msc.plugin import build_plugins
from tvm.contrib.msc.core.utils.namespace import MSCFramework
from tvm.contrib.msc.core import utils as msc_utils
Expand Down Expand Up @@ -287,6 +288,39 @@ def _test_torch_plugin(manager):
assert outputs.min() >= 0 and outputs.max() <= 0.5


def _test_with_manager(plugins, compile_type, expected_info):
"""Test the plugin with manager"""

path = "test_plugin_" + compile_type
model = _get_torch_model(plugins[MSCFramework.TORCH])
if torch.cuda.is_available():
model = model.to(torch.device("cuda:0"))
config = {
"workspace": msc_utils.msc_dir(path),
"model_type": MSCFramework.TORCH,
"verbose": "critical",
"inputs": [["input_0", [1, 3, 224, 224], "float32"]],
"outputs": ["output"],
"dataset": {"prepare": {"loader": "from_random", "max_iter": 5}},
"prepare": {"profile": {"benchmark": {"repeat": 10}}},
"baseline": {
"profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark": {"repeat": 10}},
},
"compile": {
"run_type": compile_type,
"profile": {"check": {"atol": 1e-2, "rtol": 1e-2}, "benchmark": {"repeat": 10}},
},
}
manager = MSCManager(model, config, plugins=plugins)
report = manager.run_pipe()
model_info = manager.runner.model_info
manager.destory()
assert report["success"], "Failed to run pipe for torch -> {}".format(compile_type)
assert msc_utils.dict_equal(
model_info, expected_info
), "Model info {} mismatch with expected {}".format(model_info, expected_info)


def test_plugin():
"""Test the plugins"""

Expand All @@ -302,6 +336,30 @@ def test_plugin():
_test_tvm_plugin(managers[MSCFramework.TVM], "cuda")
_test_torch_plugin(managers[MSCFramework.TORCH])

# test the plugin with manager
model_info = {
"inputs": [
{"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "output", "shape": [1, 6, 218, 218], "dtype": "float32", "layout": "NCHW"}
],
"nodes": {"total": 4, "input": 1, "msc.conv2d_bias": 1, "MyRelu": 1, "nn.max_pool2d": 1},
}
_test_with_manager(managers, MSCFramework.TORCH, model_info)
_test_with_manager(managers, MSCFramework.TVM, model_info)
if tvm.get_global_func("relax.ext.tensorrt", True) is not None:
byoc_info = {
"inputs": [
{"name": "input_0", "shape": [1, 3, 224, 224], "dtype": "float32", "layout": "NCHW"}
],
"outputs": [
{"name": "output", "shape": [1, 6, 218, 218], "dtype": "float32", "layout": ""}
],
"nodes": {"total": 2, "input": 1, "msc_tensorrt": 1},
}
_test_with_manager(managers, MSCFramework.TENSORRT, byoc_info)

plugin_root.destory()


Expand Down