Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][Relay] Fix MXNet frontend to support NLP backbones in GluonNLP #6699

Merged
merged 18 commits into from Oct 18, 2020

Conversation

sxjscience
Copy link
Member

Fix the MXNet 2.0 integration in relay. Tested the BERT and ALBERT model in the new GluonNLP 1.0 and has passed the test. I will later add unittests in GluonNLP side to ensure that the backbones can be run with the graph runtime.

import mxnet as mx
import numpy as np
import gluonnlp
from gluonnlp.models import get_backbone
import numpy.testing as npt

mx.npx.set_np()

model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone('google_albert_base_v2')

model = model_cls.from_cfg(cfg)
model.load_parameters(backbone_param_path)
model.hybridize()


batch_size = 1
seq_length = 128
token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length), dtype=np.int32)
token_types = mx.np.random.randint(0, 2, (batch_size, seq_length), dtype=np.int32)
valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,), dtype=np.int32)
mx_out = model(token_ids, token_types, valid_length)

import tvm
from tvm import relay
import tvm.contrib.graph_runtime as runtime

shape_dict = {
    'data0': (batch_size, seq_length),
    'data1': (batch_size, seq_length),
    'data2': (batch_size,)
}

dtype_dict = {
    'data0': 'int32',
    'data1': 'int32',
    'data2': 'int32'
}

sym = model._cached_graph[1]

params = {}
for k, v in model.collect_params().items():
    params[v._var_name] = tvm.nd.array(v.data().asnumpy())
mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict, dtype=dtype_dict, arg_params=params)
print(mod)
# G4
target = "cuda -model=t4"

with relay.build_config(opt_level=3, required_pass=["FastMath"]):
    graph, lib, cparams = relay.build(mod, target, params=params)

ctx = tvm.gpu()
rt = runtime.create(graph, lib, ctx)
rt.set_input(**cparams)
rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
rt.run()
for i in range(rt.get_num_outputs()):
    out = rt.get_output(i)
    print(out.asnumpy())# verify the correctness
    npt.assert_allclose(out.asnumpy(), mx_out[i].asnumpy(), rtol=1e-3, atol=1e-2)

Update type_relations.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

update

Update mxnet.py

debug

Update generic.py

Update topi_integration.py

fix bug

update

Update test_forward.py

Update test_forward.py

fix test case

Update mxnet.py

update

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

debug

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py
@sxjscience
Copy link
Member Author

@yzhliu @comaniac @icemelon9

python/tvm/relay/frontend/mxnet.py Outdated Show resolved Hide resolved
python/tvm/relay/frontend/mxnet.py Outdated Show resolved Hide resolved
python/tvm/relay/frontend/mxnet.py Outdated Show resolved Hide resolved
python/tvm/relay/frontend/mxnet.py Outdated Show resolved Hide resolved
python/tvm/topi/x86/batch_matmul.py Show resolved Hide resolved
python/tvm/relay/frontend/mxnet.py Outdated Show resolved Hide resolved
@jroesch
Copy link
Member

jroesch commented Oct 16, 2020

As we add more tests can we measure what kind of time increase this will induce in CI? integration tests are becoming increasingly slow and expensive to run. cc @areusch and @tkonolige

@sxjscience
Copy link
Member Author

The integration tests take a very long time because there are two many combinations. For example: https://github.com/apache/incubator-tvm/blob/461e75bd5ffaf45a0f270998514d444463d11261/tests/python/frontend/mxnet/test_forward.py#L2119-L2125

We may try to simplify the tests by not using a full cartesian product

@sxjscience
Copy link
Member Author

sxjscience commented Oct 18, 2020

I've verified the TVM integration with 5 NLP backbones in GluonNLP: BERT, ALBERT, ELECTRA, RoBERTA, and BART

import mxnet as mx
import numpy as np
import gluonnlp
from gluonnlp.models import get_backbone
import numpy.testing as npt
import tvm
from tvm import relay
import tvm.contrib.graph_runtime as runtime


mx.npx.set_np()

instance_info = {
    'g4': {'target': "cuda -model=t4", 'use_gpu': True},
    'c4': {'target': 'llvm -mcpu=core-avx2 -libs=cblas', 'use_gpu': False},
    'c5': {'target': 'llvm -mcpu=skylake-avx512 -libs=cblas', 'use_gpu': False},
    'p3': {'target': 'cuda -model=v100', 'use_gpu': True}
}


def test_backbone(model_name, batch_size=2, seq_length=128, instance='g4',
                  required_pass=None, opt_level=3):
    if required_pass is None:
        required_pass = ["FastMath"]
    model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(model_name)
    model = model_cls.from_cfg(cfg)
    model.load_parameters(backbone_param_path)
    model.hybridize()
    token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length), dtype=np.int32)
    token_types = mx.np.random.randint(0, 2, (batch_size, seq_length), dtype=np.int32)
    valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,), dtype=np.int32)
    if 'bart' in model_name:
        mx_out = model(token_ids, valid_length, token_ids, valid_length)
        shape_dict = {
            'data0': token_ids.shape,
            'data1': valid_length.shape,
            'data2': token_ids.shape,
            'data3': valid_length.shape,
        }
        dtype_dict = {
            'data0': token_ids.dtype.name,
            'data1': valid_length.dtype.name,
            'data2': token_ids.dtype.name,
            'data3': valid_length.dtype.name,
        }
    elif 'roberta' in model_name or 'xlmr' in model_name:
        mx_out = model(token_ids, valid_length)
        shape_dict = {
            'data0': token_ids.shape,
            'data1': valid_length.shape,
        }
        dtype_dict = {
            'data0': token_ids.dtype.name,
            'data1': valid_length.dtype.name,
        }
    else:
        mx_out = model(token_ids, token_types, valid_length)
        shape_dict = {
            'data0': token_ids.shape,
            'data1': token_types.shape,
            'data2': valid_length.shape
        }
        dtype_dict = {
            'data0': token_ids.dtype.name,
            'data1': token_types.dtype.name,
            'data2': valid_length.dtype.name
        }
    sym = model._cached_graph[1]
    params = {}
    for k, v in model.collect_params().items():
        params[v._var_name] = tvm.nd.array(v.data().asnumpy())
    mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict, dtype=dtype_dict, arg_params=params)
    target = instance_info[instance]['target']
    use_gpu = instance_info[instance]['use_gpu']
    with relay.build_config(opt_level=opt_level, required_pass=required_pass):
        graph, lib, cparams = relay.build(mod, target, params=params)
    if use_gpu:
        ctx = tvm.gpu()
    else:
        ctx = tvm.cpu()
    rt = runtime.create(graph, lib, ctx)
    rt.set_input(**cparams)
    if 'bart' in model_name:
        rt.set_input(data0=token_ids, data1=valid_length, data2=token_ids, data3=valid_length)
    elif 'roberta' in model_name:
        rt.set_input(data0=token_ids, data1=valid_length)
    else:
        rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
    rt.run()
    for i in range(rt.get_num_outputs()):
        out = rt.get_output(i)
        if rt.get_num_outputs() == 1:
            mx_out_gt = mx_out.asnumpy()
        else:
            mx_out_gt = mx_out[i].asnumpy()
        if 'mobilebert' in model_name and len(out.shape) == 3:
            npt.assert_allclose(out.asnumpy()[:, 1:, :], mx_out[i].asnumpy()[:, 1:, :],
                                rtol=6e-2, atol=6e-2)
        else:
            npt.assert_allclose(out.asnumpy(), mx_out_gt, rtol=6e-2, atol=6e-2)
# test_backbone('google_en_cased_bert_base', instance='g4')
test_model_names = ['google_albert_base_v2',
                    'google_en_cased_bert_base',
                    'google_electra_small',
                    'fairseq_roberta_base',
                    'fairseq_bart_base']
for model_name in test_model_names:
    test_backbone(model_name, instance='g4')

Copy link
Member

@yzhliu yzhliu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to me. Thanks @sxjscience

Copy link
Contributor

@comaniac comaniac left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks!

@comaniac comaniac merged commit 7c2a2e5 into apache:main Oct 18, 2020
@comaniac
Copy link
Contributor

Thanks @sxjscience @yzhliu. The test simplification could be in the follow up PRs.

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Oct 29, 2020
…nNLP (apache#6699)

* update

Update type_relations.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

update

Update mxnet.py

debug

Update generic.py

Update topi_integration.py

fix bug

update

Update test_forward.py

Update test_forward.py

fix test case

Update mxnet.py

update

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

debug

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

* address comments

* Update mxnet.py

* Update mxnet.py

* fix

* improve where test

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

* update

* Update mxnet.py

* Update mxnet.py

* Update mxnet.py

debug

Update common.py

update

Update mxnet.py

update

Update test_forward.py

Update test_forward.py

* update

* fix lint

* Update mxnet.py

* Update test_op_level1.py

* fix lint
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Dec 2, 2020
…nNLP (apache#6699)

* update

Update type_relations.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

update

Update mxnet.py

debug

Update generic.py

Update topi_integration.py

fix bug

update

Update test_forward.py

Update test_forward.py

fix test case

Update mxnet.py

update

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

debug

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

* address comments

* Update mxnet.py

* Update mxnet.py

* fix

* improve where test

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

* update

* Update mxnet.py

* Update mxnet.py

* Update mxnet.py

debug

Update common.py

update

Update mxnet.py

update

Update test_forward.py

Update test_forward.py

* update

* fix lint

* Update mxnet.py

* Update test_op_level1.py

* fix lint
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Dec 4, 2020
…nNLP (apache#6699)

* update

Update type_relations.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

update

Update mxnet.py

debug

Update generic.py

Update topi_integration.py

fix bug

update

Update test_forward.py

Update test_forward.py

fix test case

Update mxnet.py

update

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

debug

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

* address comments

* Update mxnet.py

* Update mxnet.py

* fix

* improve where test

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

* update

* Update mxnet.py

* Update mxnet.py

* Update mxnet.py

debug

Update common.py

update

Update mxnet.py

update

Update test_forward.py

Update test_forward.py

* update

* fix lint

* Update mxnet.py

* Update test_op_level1.py

* fix lint
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Dec 4, 2020
…nNLP (apache#6699)

* update

Update type_relations.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update transform.cc

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

update

Update mxnet.py

debug

Update generic.py

Update topi_integration.py

fix bug

update

Update test_forward.py

Update test_forward.py

fix test case

Update mxnet.py

update

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

Update mxnet.py

Update mxnet.py

debug

Update mxnet.py

Update mxnet.py

Update test_forward.py

Update mxnet.py

* address comments

* Update mxnet.py

* Update mxnet.py

* fix

* improve where test

* Update test_forward.py

* Update test_forward.py

* Update test_forward.py

* update

* Update mxnet.py

* Update mxnet.py

* Update mxnet.py

debug

Update common.py

update

Update mxnet.py

update

Update test_forward.py

Update test_forward.py

* update

* fix lint

* Update mxnet.py

* Update test_op_level1.py

* fix lint
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants