# MIGraphX Accelerator Example

In [1]:
import sys
import os
from pathlib import Path
libpath = os.path.join(Path.cwd().parents[1],'py')
sys.path.append(libpath)


import torch
import migraphx
import torchvision.models as models
from torch_migraphx.fx.fx2mgx import MGXInterpreter
from torch_migraphx.fx.mgx_module import MGXModule
import torch_migraphx.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_migraphx.fx.tracer.acc_tracer.acc_shape_prop import AccShapeProp
from torch_migraphx.fx.tools.mgx_splitter import MGXSplitter

from torch.fft import fft2


Below we define a simple network we will lower to migraphx. It also contains an unsupported operation (fft2) we must handle.

In [2]:
class ConvNet(torch.nn.Module):
    """Small conv -> batch-norm -> ReLU -> |fft2| -> flatten -> linear network.

    The fft2 call is deliberately included as an operation that is not
    lowered to MIGraphX, so the graph has to be split around it.

    Args:
        k: Convolution kernel size.
        in_ch: Number of input channels; the convolution emits ``in_ch * 2``.
        height: Spatial height of the expected input (default 224).
        width: Spatial width of the expected input (default 224).
            ``height`` and ``width`` only size the final linear layer; the
            convolution uses ``padding='same'`` so spatial dims are preserved.
    """

    def __init__(self, k, in_ch, height=224, width=224):
        super().__init__()
        out_ch = in_ch * 2
        self.conv = torch.nn.Conv2d(in_ch, out_ch, k, padding='same')
        self.bn = torch.nn.BatchNorm2d(out_ch)
        self.relu = torch.nn.ReLU()
        # Flattened feature count is channels * height * width.
        self.linear = torch.nn.Linear(height * width * out_ch, 64)

    def forward(self, x):
        """Return a ``(batch, 64)`` tensor for an input of shape ``(batch, in_ch, height, width)``."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = fft2(x).abs()  #unsupported op
        x = x.flatten(1)
        x = self.linear(x)
        return x

In [14]:
# Seed the RNG so the randomly initialised weights and the random sample
# batch are the same on every Restart & Run All.
torch.manual_seed(0)

k, in_ch = 3, 3
model = ConvNet(k, in_ch)
model.eval()
# Channel count follows in_ch so the sample batch always matches the model.
sample_inputs = [torch.randn(50, in_ch, 224, 224)]

# Move both the model and the sample batch to the GPU.
model = model.cuda()
sample_inputs = [i.cuda() for i in sample_inputs]

First we use our custom fx tracer (acc_tracer) to generate a graph representation of the above module. The custom tracer also normalizes all supported torch operations to map to acc ops.

In [15]:
# Trace the model with the custom acc_tracer, using the sample inputs,
# then print the resulting FX graph as a table.
traced = acc_tracer.trace(model, sample_inputs)
traced.graph.print_tabular()

opcode         name             target                                   args         kwargs
-------------  ---------------  ---------------------------------------  -----------  ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------
placeholder    x                x                                        ()           {}
get_attr       conv_weight      conv.weight                              ()           {}
get_attr       conv_bias        conv.bias                                ()           {}
call_function  conv2d_1         <function conv2d at 0x7f4b3faae160>      ()           {'input': x, 'weight': conv_weight, 'bias': conv_bias, 'stride': (1, 1), 'padding': 'same', 'dilation': (1, 1), 'groups': 1}
get_attr       bn_weight        bn.weight                                ()           {}
get_attr       bn_bias          bn.bias                                  ()       

Next, we split the graph into subgraphs that are supported by migraphx and ones that need to run via the torch implementation. Submodules named '_run_on_acc_{}' are marked to be lowered to migraphx, and the ones named '_run_via_torch_{}' are marked to be executed through their original torch implementation.

In [16]:
# Build the splitter from the traced module and the sample inputs.
splitter = MGXSplitter(traced, sample_inputs)
# Prints the supported / unsupported node types; the return value is bound
# to `_` so it is not echoed as the cell's output.
_ = splitter.node_support_preview()


Supported node types in the model:
acc_ops.conv2d: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})
acc_ops.batch_norm: ((), {'input': torch.float32, 'running_mean': torch.float32, 'running_var': torch.float32, 'weight': torch.float32, 'bias': torch.float32})
acc_ops.relu: ((), {'input': torch.float32})
acc_ops.flatten: ((), {'input': torch.float32})
acc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})

Unsupported node types in the model:
torch._C._fft.fft_fft2: ((torch.float32,), {})
abs: ((torch.complex64,), {})



In [17]:
# Calling the splitter produces the split module; print its top-level graph
# to see how the accelerator and torch submodules are chained.
split_mod = splitter()
print(split_mod.graph)

graph():
    %x : [#users=1] = placeholder[target=x]
    %_run_on_acc_0 : [#users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
    %_run_via_torch_1 : [#users=1] = call_module[target=_run_via_torch_1](args = (%_run_on_acc_0,), kwargs = {})
    %_run_on_acc_2 : [#users=1] = call_module[target=_run_on_acc_2](args = (%_run_via_torch_1,), kwargs = {})
    return _run_on_acc_2


In [6]:
# Show the graph of each submodule produced by the split, in execution order.
for submod_name in ("_run_on_acc_0", "_run_via_torch_1", "_run_on_acc_2"):
    print(getattr(split_mod, submod_name).graph)

graph():
    %x : [#users=1] = placeholder[target=x]
    %conv_weight : [#users=1] = get_attr[target=conv.weight]
    %conv_bias : [#users=1] = get_attr[target=conv.bias]
    %conv2d_1 : [#users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.conv2d](args = (), kwargs = {input: %x, weight: %conv_weight, bias: %conv_bias, stride: (1, 1), padding: same, dilation: (1, 1), groups: 1})
    %bn_running_mean : [#users=1] = get_attr[target=bn.running_mean]
    %bn_running_var : [#users=1] = get_attr[target=bn.running_var]
    %bn_weight : [#users=1] = get_attr[target=bn.weight]
    %bn_bias : [#users=1] = get_attr[target=bn.bias]
    %batch_norm_1 : [#users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.batch_norm](args = (), kwargs = {input: %conv2d_1, running_mean: %bn_running_mean, running_var: %bn_running_var, weight: %bn_weight, bias: %bn_bias, training: False, momentum: 0.1, eps: 1e-05})
    %relu_1 : [#users=1] = call_function[target=torch_mi

Next, we convert any submodules that are eligible to be lowered to migraphx

In [18]:
# Need sample inputs when lowering submodules
def get_submod_inputs(mod, submod, inputs):
    """Capture the positional inputs ``submod`` receives when ``mod`` runs on ``inputs``.

    A temporary forward pre-hook is registered on ``submod``, ``mod(*inputs)``
    is executed once, and the hook is removed again.

    Args:
        mod: The parent module to execute.
        submod: The submodule of ``mod`` whose inputs should be recorded.
        inputs: Sequence of positional arguments passed to ``mod``.

    Returns:
        The tuple of positional arguments passed to ``submod`` on its most
        recent call, or ``None`` if ``submod`` was never called.
    """
    acc_inputs = None

    def get_input(module, args):
        # Returning None from a pre-hook leaves the inputs unmodified.
        nonlocal acc_inputs
        acc_inputs = args

    handle = submod.register_forward_pre_hook(get_input)
    try:
        mod(*inputs)
    finally:
        # Always detach the hook, even if the forward pass raises, so the
        # submodule is not left with a stale capture hook attached.
        handle.remove()
    return acc_inputs

In [19]:
# Snapshot the children first: the loop replaces entries on split_mod, so we
# must not iterate the live module view while modifying it.
for name, submod in list(split_mod.named_children()):
    if "_run_on_acc" not in name:
        continue

    # Capture the sample inputs this submodule sees; the fx2mgx conversion
    # needs them.
    acc_inputs = get_submod_inputs(split_mod, submod, sample_inputs)
    AccShapeProp(submod).propagate(*acc_inputs)

    # fx2mgx: interpret the FX submodule into a MIGraphX program
    interp = MGXInterpreter(submod, acc_inputs)
    interp.run()
    mgx_mod = MGXModule(interp.program, interp.get_input_names())

    # Swap the lowered module in for the original FX submodule.
    setattr(split_mod, name, mgx_mod)



The creation of MGXModule automatically runs all optimization passes available in MIGraphX and stores the compiled program. We can see the hip instructions by printing the stored programs.

In [20]:
# Dump the compiled MIGraphX program stored on each lowered submodule.
for acc_name in ("_run_on_acc_0", "_run_on_acc_2"):
    getattr(split_mod, acc_name).program.print()

module: "main"
main:@0 = check_context::migraphx::version_1::gpu::context -> float_type, {}, {}
main:@1 = hip::hip_allocate_memory[shape=float_type, {0}, {1},id=main:scratch] -> float_type, {0}, {1}
main:@2 = hip::hip_copy_literal[id=main:@literal:1] -> float_type, {6}, {1}
main:@3 = hip::hip_copy_literal[id=main:@literal:0] -> float_type, {6, 3, 3, 3}, {27, 9, 3, 1}
main:#output_0 = @param:main:#output_0 -> float_type, {50, 6, 224, 224}, {301056, 50176, 224, 1}
main:@5 = broadcast[axis=1,out_lens={50, 6, 224, 224}](main:@2) -> float_type, {50, 6, 224, 224}, {0, 1, 0, 0}
x = @param:x -> float_type, {50, 3, 224, 224}, {150528, 50176, 224, 1}
main:@7 = gpu::miopen_fusion[ops={{op=convolution[padding={1, 1, 1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=0],alpha=1,beta=0}, {op=add,alpha=1,beta=0}, {op=relu,alpha=1,beta=0}}](x,main:@3,main:@5,main:#output_0) -> float_type, {50, 6, 224, 224}, {301056, 50176, 224, 1}
main:@8 = @return(main:@7)


module: "main"
main:@0 = check_conte

Below we ensure that the converted modules produce the same output as the original model

In [26]:
split_mod = split_mod.cuda()
model = model.cuda()
sample_inputs = [i.cuda() for i in sample_inputs]

# Inference only: disable autograd so neither forward pass records a graph
# or keeps intermediate activations alive on the GPU.
with torch.no_grad():
    torch_out = model(*sample_inputs)
    lowered_model_out = split_mod(*sample_inputs)

# The lowered model must match the reference within these tolerances.
torch.testing.assert_close(torch_out,
                            lowered_model_out,
                            atol=3e-3,
                            rtol=1e-2)


Modules that contain MGXModules as submodules can be saved and loaded in the same manner as torch modules. 

In [27]:
# Serialise the whole split module (including the MGXModule children) to disk.
torch.save(obj=split_mod, f='split_mod.pt')

In [28]:
# The file holds a fully pickled module, not just a state dict, so it must be
# loaded with weights_only=False (newer PyTorch defaults to weights_only=True
# and refuses to unpickle arbitrary modules).
# NOTE: this executes pickle code; only load files you created or trust.
reload_split_mod = torch.load('split_mod.pt', weights_only=False)

In [29]:
# Run the reloaded module on the same batch and check that it still matches
# the reference torch output.
reload_mod_out = reload_split_mod(*sample_inputs)

tolerances = {"atol": 3e-3, "rtol": 1e-2}
torch.testing.assert_close(torch_out, reload_mod_out, **tolerances)