# MIGraphX Accelerator Example

In [None]:
import sys
import os
from pathlib import Path
libpath = os.path.join(Path.cwd().parents[1],'py')
sys.path.append(libpath)
sys.path.append('/opt/rocm/lib')


import torch
import migraphx
import torchvision.models as models
from torch_migraphx.fx.fx2mgx import MGXInterpreter
from torch_migraphx.fx.mgx_module import MGXModule
import torch_migraphx.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_migraphx.fx.tracer.acc_tracer.acc_shape_prop import AccShapeProp
from torch_migraphx.fx.tools.mgx_splitter import MGXSplitter

from torch.fft import fft2


Below, we define a simple network that we will lower to MIGraphX. It also contains an operation that MIGraphX does not support (`fft2`), which we must handle separately.

In [None]:
class ConvNet(torch.nn.Module):
    """Small conv -> batchnorm -> relu -> |fft2| -> linear example network.

    The ``fft2`` call is included on purpose as an operation MIGraphX does not
    support, so the graph splitter must keep it in a torch-executed submodule.

    Args:
        k: Convolution kernel size. ``padding='same'`` keeps the spatial size.
        in_ch: Number of input channels; the convolution doubles it.
        img_size: Spatial size of the expected input, either an int (square
            images) or an ``(height, width)`` pair. It determines the input
            size of the final linear layer. Defaults to 224, the size used in
            this example.
    """

    def __init__(self, k, in_ch, img_size=224):
        super().__init__()
        if isinstance(img_size, int):
            height, width = img_size, img_size
        else:
            height, width = img_size
        out_ch = in_ch * 2
        self.conv = torch.nn.Conv2d(in_ch, out_ch, k, padding='same')
        self.bn = torch.nn.BatchNorm2d(out_ch)
        self.relu = torch.nn.ReLU()
        # 'same' padding preserves H x W, so after flatten(1) each sample has
        # out_ch * H * W features.
        self.linear = torch.nn.Linear(height * width * out_ch, 64)

    def forward(self, x):
        """Run the network on a batch ``x`` and return a (N, 64) tensor.

        Expects a batched (N, in_ch, H, W) tensor whose H x W matches
        ``img_size``; otherwise the linear layer's input size will not match.
        """
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = fft2(x).abs()  # unsupported op: keeps this part in torch
        x = x.flatten(1)
        x = self.linear(x)
        return x

In [None]:
# Build the example model and a batch of random inputs, then move both to the GPU.
k, in_ch = 3, 3
model = ConvNet(k, in_ch)
model.eval()  # use BatchNorm running statistics instead of batch statistics
# Take the channel count from in_ch so the sample input always matches the
# conv layer. 224x224 is the spatial size ConvNet's linear layer is built for.
sample_inputs = [torch.randn(50, in_ch, 224, 224)]

model = model.cuda()
sample_inputs = [i.cuda() for i in sample_inputs]

First, we use our custom FX tracer (`acc_tracer`) to generate a graph representation of the module above. The custom tracer also normalizes every supported torch operation to its corresponding acc op.

In [None]:
# Trace the model with the custom acc tracer, using the sample inputs, and
# print the resulting fx graph as a table of nodes.
traced = acc_tracer.trace(model, sample_inputs)
traced.graph.print_tabular()

Next, we split the graph into subgraphs that MIGraphX supports and subgraphs that must run through the torch implementation. Submodules named `_run_on_acc_{}` are marked to be lowered to MIGraphX, and those named `_run_via_torch_{}` are executed through their original torch implementation.

In [None]:
# Build the splitter from the traced graph and the sample inputs.
splitter = MGXSplitter(traced, sample_inputs)
# Presumably reports which nodes are supported by MIGraphX (confirm in
# torch_migraphx); the return value is intentionally discarded.
_ = splitter.node_support_preview()

In [None]:
# Run the splitter to get a module whose children are the partitioned
# submodules, then print the top-level graph that calls them.
split_mod = splitter()
print(split_mod.graph)

In [None]:
# Print the fx graph of every partition the splitter produced, in call order.
for part_name in ("_run_on_acc_0", "_run_via_torch_1", "_run_on_acc_2"):
    print(getattr(split_mod, part_name).graph)

Next, we convert every submodule that is eligible for lowering to MIGraphX.

In [None]:
# Lowering a submodule needs concrete sample inputs for that submodule.
def get_submod_inputs(mod, submod, inputs):
    """Capture the positional arguments ``submod`` receives while ``mod`` runs.

    A forward pre-hook is registered on ``submod`` for the duration of a
    single call ``mod(*inputs)``, and it records the arguments passed to
    ``submod``.

    Args:
        mod: Parent module to execute.
        submod: Submodule that is invoked inside ``mod``'s forward pass.
        inputs: Sequence of positional arguments for ``mod``.

    Returns:
        The tuple of positional arguments ``submod`` was called with. If it
        is called more than once, this is the last call.

    Raises:
        RuntimeError: If ``submod`` was never called while running ``mod``.
    """
    captured = None

    def _record_inputs(_module, args):
        nonlocal captured
        captured = args

    handle = submod.register_forward_pre_hook(_record_inputs)
    try:
        mod(*inputs)
    finally:
        # Always detach the hook, even when the forward pass raises, so the
        # submodule is not left with a stale hook attached.
        handle.remove()

    if captured is None:
        raise RuntimeError(
            f"{type(submod).__name__} was not called during the forward "
            f"pass of {type(mod).__name__}"
        )
    return captured

In [None]:
# Lower every submodule the splitter marked as MIGraphX-eligible and swap the
# resulting MGXModule into split_mod in its place. The children are
# snapshotted first because setattr below replaces entries in the module's
# child registry while we loop.
for name, submod in list(split_mod.named_children()):
    if not name.startswith("_run_on_acc"):
        continue  # '_run_via_torch_*' submodules stay as torch code

    # Capture the concrete inputs this submodule receives, then run shape
    # propagation over its graph (presumably recording tensor metadata on the
    # nodes for the interpreter; confirm in AccShapeProp).
    acc_inputs = get_submod_inputs(split_mod, submod, sample_inputs)
    AccShapeProp(submod).propagate(*acc_inputs)

    # Interpret the fx graph into a MIGraphX program and wrap it in a module.
    interp = MGXInterpreter(submod, acc_inputs)
    interp.run()
    mgx_mod = MGXModule(interp.program, interp.get_input_names())

    setattr(split_mod, name, mgx_mod)



Creating an MGXModule automatically runs all optimization passes available in MIGraphX and stores the compiled program. We can see the resulting HIP instructions by printing the stored programs.

In [None]:
# Print the compiled MIGraphX program held by each lowered submodule.
for acc_name in ("_run_on_acc_0", "_run_on_acc_2"):
    getattr(split_mod, acc_name).program.print()

Below, we check that the converted modules produce the same output as the original model, within a small numerical tolerance.

In [None]:
# Make sure the modules and inputs all live on the GPU before comparing.
split_mod = split_mod.cuda()
model = model.cuda()
sample_inputs = [i.cuda() for i in sample_inputs]

torch_out = model(*sample_inputs)
lowered_model_out = split_mod(*sample_inputs)

# The eager torch output is the reference. assert_close measures the relative
# tolerance against `expected`, so the reference goes in that slot.
torch.testing.assert_close(actual=lowered_model_out,
                           expected=torch_out,
                           atol=3e-3,
                           rtol=1e-2)


Modules that contain MGXModules as submodules can be saved and loaded in the same way as regular torch modules.

In [None]:
# Pickle the entire module object (including its MGXModule submodules), not
# just a state_dict, to split_mod.pt in the current working directory.
torch.save(split_mod, 'split_mod.pt')

In [None]:
# torch.load unpickles arbitrary Python objects. Restoring a whole nn.Module
# (rather than a state_dict) requires weights_only=False, which recent PyTorch
# releases no longer use as the default. Only load checkpoints you trust.
reload_split_mod = torch.load('split_mod.pt', weights_only=False)

In [None]:
reload_mod_out = reload_split_mod(*sample_inputs)

# Compare the reloaded module against the eager torch reference. assert_close
# measures the relative tolerance against `expected`.
torch.testing.assert_close(actual=reload_mod_out,
                           expected=torch_out,
                           atol=3e-3,
                           rtol=1e-2)

In [None]:
# High-level API: lower_to_mgx presumably wraps the trace / split / lower
# steps shown above in a single call (confirm against torch_migraphx.fx).
from torch_migraphx.fx import lower_to_mgx
lowered_model = lower_to_mgx(model, sample_inputs)

In [None]:
lowered_out = lowered_model(*sample_inputs)
torch_out = model(*sample_inputs)
# The eager torch output is the reference (`expected`); assert_close measures
# the relative tolerance against it.
torch.testing.assert_close(actual=lowered_out,
                           expected=torch_out,
                           atol=3e-3,
                           rtol=1e-2)