# 将 PyTorch 模型转换为 ONNX

In [1]:
%cd ..
from utils.onnx_utils import (
    get_input_data_shape_dict,
    make_constant_node, get_onnxruntime_output,
    get_tvm_output, get_tvm_output_with_vm,
    verify_with_ort, verify_with_ort_with_inputs,
    quantize_and_verify_with_ort
)

/media/pc/data/lxw/ai/tvm/xinetzone/tvm-book/doc/tutorials/frontend


In [2]:
import tvm
from tvm import relay

## PyTorch 算子测试

### `unsqueeze_constant`

In [3]:
import numpy as np
import torch
from torch import nn
import onnx
from tvm import relay
import tempfile

class Flatten(nn.Module):
    """Collapse every dimension except the first one.

    Implemented as ``input_.view(input_.size(0), -1)`` rather than with
    ``nn.Flatten``; ``view`` requires a layout compatible with the new shape
    (e.g. a contiguous tensor).
    """

    def forward(self, input_):
        batch_size = input_.size(0)
        return input_.view(batch_size, -1)

# Export a Flatten -> Linear model to a temporary ONNX file and check that the
# Relay ONNX frontend can import it.
# The view-based ``Flatten`` class above (``input_.view(input_.size(0), -1)``) is
# used instead of ``nn.Flatten`` so the exported graph builds the reshape target
# from the input's size, which is presumably what the ``unsqueeze_constant``
# section is meant to exercise (NOTE(review): confirm against the exported graph
# for your torch version).
with tempfile.NamedTemporaryFile() as f:
    file_name = f.name
    input_size = (1, 16, 32, 32)
    dummy_input = torch.randn(*input_size)
    layer = nn.Sequential(Flatten(), nn.Linear(16 * 32 * 32, 64))
    torch.onnx.export(layer, dummy_input, file_name, export_params=True)

    onnx_model = onnx.load(file_name)
    # The exporter auto-generates the graph input name (it was "onnx::Flatten_0"
    # with nn.Flatten and appears to depend on the first op), so read it from the
    # graph instead of hard-coding it.
    input_name = onnx_model.graph.input[0].name
    relay.frontend.from_onnx(onnx_model, {input_name: input_size})

verbose: False, log level: Level.ERROR



### `embedding_bag`

In [4]:
def test_aten(target, dev):
    """Compare ``torch.nn.EmbeddingBag`` with TVM after an ``ONNX_ATEN`` export.

    Each case exports the module with ``opset_version=10`` and
    ``OperatorExportTypes.ONNX_ATEN``, runs it through
    ``get_tvm_output_with_vm`` and checks the result against PyTorch's output.

    Parameters
    ----------
    target : str
        TVM target forwarded to ``get_tvm_output_with_vm`` (e.g. ``"llvm"``).
    dev : tvm.runtime.Device
        Device forwarded to ``get_tvm_output_with_vm`` (e.g. ``tvm.cpu()``).
    """
    import os
    import tempfile

    def _convert_to_onnx(model, inputs):
        # Export into a throw-away directory so no ``aten_model.onnx`` is left
        # in the working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            file_name = os.path.join(tmp_dir, "aten_model.onnx")
            torch.onnx.export(
                model,
                inputs,
                file_name,
                export_params=True,
                verbose=False,
                opset_version=10,
                operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN,
            )
            return onnx.load(file_name)

    def verify_embedding_bag(num_embedding, embedding_dim, data_shape):
        # torch.randint's upper bound is exclusive: use num_embedding so every
        # row of the embedding table, including the last one, can be selected.
        dummy_data = torch.randint(0, num_embedding, data_shape)
        tvm_inputs = [dummy_data.numpy()]
        model = torch.nn.EmbeddingBag(num_embedding, embedding_dim)
        onnx_model = _convert_to_onnx(model, dummy_data)
        torch_out = model(dummy_data)
        tvm_out = get_tvm_output_with_vm(
            onnx_model,
            tvm_inputs,
            freeze_params=True,
            target=target,
            dev=dev,
        )
        np.testing.assert_allclose(torch_out.numpy(), tvm_out, atol=5e-7)

    # Disable autograd only for this test (``.numpy()`` needs a tensor without
    # grad) instead of switching it off globally for the rest of the notebook.
    with torch.no_grad():
        verify_embedding_bag(10, 3, [2, 10])
        verify_embedding_bag(32, 2, [3, 3])

In [5]:
# NOTE: when this notebook was last run, this call raised
# "AssertionError: Operator numel is not supported." (presumably from TVM's
# ONNX frontend handling the ATen export).
test_aten(target="llvm", dev=tvm.cpu())

verbose: False, log level: Level.ERROR



AssertionError: Operator numel is not supported.

### `index_put` 与 `index_put_slice`

In [6]:
class IndexPutModel(torch.nn.Module):
    """Wrap ``Tensor.index_put`` with fixed indices and values so it can be exported.

    Parameters
    ----------
    indices : sequence of torch.Tensor
        Index tensors, one per indexed dimension.
    values : torch.Tensor
        Values placed at ``indices``.
    accumulate : bool
        If true, ``values`` are added to the existing entries instead of
        overwriting them.
    """

    def __init__(self, indices, values, accumulate):
        super().__init__()
        self.indices, self.values, self.accumulate = indices, values, accumulate

    def forward(self, x):
        """Return a new tensor equal to ``x`` with ``values`` put at ``indices``."""
        return x.index_put(self.indices, self.values, accumulate=self.accumulate)

def _convert_to_onnx(
    model,
    dummy_data,
    opset_version=11,
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
):
    """Export ``model`` to ONNX by tracing it on ``dummy_data`` and load the result.

    The ONNX file is written into a temporary directory that is removed before
    returning, so no ``aten_model.onnx`` is left in the working directory.

    Parameters
    ----------
    model : torch.nn.Module
        Module to export.
    dummy_data : torch.Tensor
        Example input used to trace ``model``.
    opset_version : int, optional
        ONNX opset passed to ``torch.onnx.export`` (default 11).
    operator_export_type : torch.onnx.OperatorExportTypes, optional
        Export mode passed to ``torch.onnx.export`` (default ``ONNX_ATEN_FALLBACK``).

    Returns
    -------
    onnx.ModelProto
        The exported model, loaded back with ``onnx.load``.
    """
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, "aten_model.onnx")
        torch.onnx.export(
            model,
            dummy_data,
            file_name,
            export_params=True,
            verbose=False,
            opset_version=opset_version,
            operator_export_type=operator_export_type,
        )
        # Load before the directory (and the file in it) is deleted.
        return onnx.load(file_name)

def verify_index_put(data_shape, indices, accumulate, target="llvm", dev=None):
    """Check ``Tensor.index_put`` exported to ONNX against TVM's Relay VM.

    A tensor of ones with shape ``data_shape`` is updated at ``indices`` with
    random values (``values`` is sized from ``indices[0]``). The PyTorch result
    is then compared with the TVM output of the exported ONNX model.

    Parameters
    ----------
    data_shape : tuple of int
        Shape of the tensor being updated.
    indices : list of torch.Tensor
        Index tensors, one per indexed dimension.
    accumulate : bool
        Forwarded to ``index_put``: add into existing entries instead of overwriting.
    target : str, optional
        TVM target, ``"llvm"`` by default.
    dev : tvm.runtime.Device, optional
        TVM device; ``tvm.cpu()`` when not given.
    """
    if dev is None:
        dev = tvm.cpu()
    dummy_data = torch.ones(data_shape)
    tvm_inputs = [dummy_data.numpy()]
    values = torch.rand(indices[0].size())
    model = IndexPutModel(indices, values, accumulate)
    onnx_model = _convert_to_onnx(model, dummy_data)
    torch_out = model(dummy_data)

    tvm_out = get_tvm_output_with_vm(onnx_model, tvm_inputs, target, dev, freeze_params=True)
    # Use numpy's checker, as verify_index_put_slice does, instead of relying on
    # ``tvm.testing`` having been imported elsewhere; rtol/atol mirror
    # tvm.testing.assert_allclose's defaults so the tolerance is unchanged.
    np.testing.assert_allclose(torch_out.numpy(), tvm_out, rtol=1e-7, atol=1e-7)

# 2-D case: four distinct positions, accumulating into the tensor of ones.
shape = (3, 5)
xidx, yidx = torch.tensor([0, 1, 2, 2]), torch.tensor([0, 1, 3, 4])
verify_index_put(shape, [xidx, yidx], accumulate=True)

# 3-D case with overwrite semantics.
# NOTE(review): position (0, 0, 0) is indexed twice here; PyTorch documents
# index_put with accumulate=False and duplicate indices as undefined, so this
# comparison may depend on write order — confirm it is intended.
shape = (3, 5, 3)
xidx, yidx, zidx = (
    torch.tensor([0, 1, 2, 2, 0]),
    torch.tensor([0, 1, 3, 4, 0]),
    torch.tensor([0, 1, 1, 2, 0]),
)
verify_index_put(shape, [xidx, yidx, zidx], accumulate=False)

def verify_index_put_slice(data_shape, value_shape, accumulate, target="llvm", dev=None):
    """Check a broadcast ("slice"-style) ``index_put`` exported to ONNX against TVM.

    For a ``value_shape`` of rank ``k``, index tensor ``i`` is
    ``arange(value_shape[i])`` reshaped to ``(value_shape[i], 1, ..., 1)`` with
    ``k - 1 - i`` trailing ones. The indices therefore broadcast to
    ``value_shape``, and a random block of that shape is written (or, with
    ``accumulate``, added) into the leading corner of a tensor of ones with shape
    ``data_shape``.

    Parameters
    ----------
    data_shape : tuple of int
        Shape of the tensor being updated.
    value_shape : tuple of int
        Shape of the block of random values; must not exceed ``data_shape``.
    accumulate : bool
        Forwarded to ``index_put``.
    target : str, optional
        TVM target, ``"llvm"`` by default.
    dev : tvm.runtime.Device, optional
        TVM device; ``tvm.cpu()`` when not given.
    """
    if dev is None:
        dev = tvm.cpu()
    dummy_data = torch.ones(data_shape)
    tvm_inputs = [dummy_data.numpy()]
    rank = len(value_shape)
    indices = [
        torch.arange(0, size).reshape((-1,) + (1,) * (rank - 1 - axis))
        for axis, size in enumerate(value_shape)
    ]
    values = torch.rand(value_shape)

    model = IndexPutModel(indices, values, accumulate)
    onnx_model = _convert_to_onnx(model, dummy_data)
    torch_out = model(dummy_data)

    tvm_out = get_tvm_output_with_vm(onnx_model, tvm_inputs, target, dev, freeze_params=True)
    np.testing.assert_allclose(torch_out.numpy(), tvm_out)

# Each case writes a random block of shape ``value_shape`` into the leading
# corner of a tensor of ones of shape ``data_shape``.
verify_index_put_slice(data_shape=(3, 3), value_shape=(2, 2), accumulate=False)
verify_index_put_slice(data_shape=(2, 3, 4), value_shape=(1, 2, 3), accumulate=True)
verify_index_put_slice(data_shape=(2, 3, 4, 5), value_shape=(1, 2, 3, 1), accumulate=False)

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR

verbose: False, log level: Level.ERROR



## `torchvision` 模型

In [None]:
import numpy as np
import torch
import torchvision
import onnx

def check_torch_conversion(model, input_size, target, dev):
    """Export a torchvision model to ONNX and verify it with ``verify_with_ort_with_inputs``.

    Parameters
    ----------
    model : callable
        Zero-argument model factory such as ``torchvision.models.resnet18``. It is
        called once, and its ``__name__`` names the exported file.
    input_size : tuple of int
        Input shape used both for the export trace and for the verification data.
    target : str
        TVM target forwarded to ``verify_with_ort_with_inputs``.
    dev : tvm.runtime.Device
        TVM device forwarded to ``verify_with_ort_with_inputs``.
    """
    import os
    import tempfile

    dummy_input = torch.randn(*input_size)
    # Each export embeds the weights (export_params=True); write it into a
    # temporary directory so no ``<model>.onnx`` file is left behind.
    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, f"{model.__name__}.onnx")
        # Set verbose=True for more output
        torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False)
        onnx_model = onnx.load(file_name)
    input_data = np.random.uniform(size=input_size).astype("float32")
    verify_with_ort_with_inputs(
        onnx_model, [input_data], apply_softmax=True, target=target, dev=dev
    )

# def test_alexnet():
# Torch's ONNX export does not support the adaptive pooling used by AlexNet?
# check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))

# Torch's ONNX export does not support the adaptive pooling used by vgg16?
# def test_vgg16():
#     check_torch_conversion(torchvision.models.vgg16, (1,3,224,224))

# TODO(@jroesch): Update Torch + ONNX to support this import.
# def test_squeezenet():
#     # Torch's ONNX export does not support the max pooling used by SqueezeNet
#     check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224))

# TODO(@jroesch): Update Torch + ONNX to support this import.
# def test_googlenet():
#     check_torch_conversion(torchvision.models.googlenet, (1,3,224,224))

# TODO(@jroesch): Update Torch + ONNX to support this import.
# def test_shufflenetv2():
#     check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))

@tvm.testing.parametrize_targets
def test_densenet(target, dev):
    """Export ``torchvision.models.densenet161`` on a 1x3x224x224 input and verify it."""
    input_shape = (1, 3, 224, 224)
    check_torch_conversion(
        torchvision.models.densenet161, input_shape, target=target, dev=dev
    )


@tvm.testing.parametrize_targets
def test_inception(target, dev):
    """Export ``torchvision.models.inception_v3`` on a 1x3x224x224 input and verify it."""
    input_shape = (1, 3, 224, 224)
    check_torch_conversion(
        torchvision.models.inception_v3, input_shape, target=target, dev=dev
    )

### `resnet18`

In [None]:
# Run the export-and-verify check for ResNet-18 with the LLVM target on the CPU.
target, dev = "llvm", tvm.cpu()
check_torch_conversion(
    torchvision.models.resnet18,
    (1, 3, 224, 224),
    target=target,
    dev=dev,
)
# check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))