Add TensorFlow handling of ONNX conversion
Add the tf2onnx and onnx packages to setup.py.
Use them in convert.py to handle ONNX conversion of TF models.
Add tests of ONNX conversion for TensorFlow models.
Alberto committed Jan 27, 2022
1 parent c189fea commit b3fc528
Showing 4 changed files with 182 additions and 40 deletions.
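
In practice, this change lets transformers.onnx.export accept a TensorFlow model where it previously required PyTorch. A minimal usage sketch mirroring the new tests below (the checkpoint, output path, and the BertOnnxConfig import are illustrative assumptions, not part of this commit; the import is assumed analogous to the AlbertOnnxConfig import in the test file):

from pathlib import Path

from transformers import AutoTokenizer, TFBertModel
from transformers.models.bert import BertOnnxConfig  # assumed, by analogy with AlbertOnnxConfig
from transformers.onnx import export
from transformers.onnx.config import DEFAULT_ONNX_OPSET

# Load a TensorFlow checkpoint and derive its ONNX export configuration
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = TFBertModel.from_pretrained("bert-base-cased")
onnx_config = BertOnnxConfig.from_model_config(model.config)

# export() now dispatches on the model class: TFPreTrainedModel subclasses
# go through tf2onnx, while PreTrainedModel subclasses keep using torch.onnx
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path("bert.onnx"))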
setup.py: 3 additions & 1 deletion

@@ -110,6 +110,7 @@
"keras2onnx",
"nltk",
"numpy>=1.17",
"onnx",
"onnxconverter-common",
"onnxruntime-tools>=1.4.2",
"onnxruntime>=1.4.0",
@@ -146,6 +147,7 @@
"starlette",
"tensorflow-cpu>=2.3",
"tensorflow>=2.3",
"tf2onnx",
"timeout-decorator",
"timm",
"tokenizers>=0.10.1,<0.11",
@@ -241,7 +243,7 @@ def run(self):

extras["tokenizers"] = deps_list("tokenizers")
extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx") + extras["onnxruntime"]
extras["onnx"] = deps_list("onnxconverter-common", "keras2onnx", "onnx", "tf2onnx") + extras["onnxruntime"]
extras["modelcreation"] = deps_list("cookiecutter")

extras["sagemaker"] = deps_list("sagemaker")
src/transformers/dependency_versions_table.py: 2 additions & 0 deletions

@@ -28,6 +28,7 @@
"keras2onnx": "keras2onnx",
"nltk": "nltk",
"numpy": "numpy>=1.17",
"onnx": "onnx",
"onnxconverter-common": "onnxconverter-common",
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
"onnxruntime": "onnxruntime>=1.4.0",
@@ -64,6 +65,7 @@
"starlette": "starlette",
"tensorflow-cpu": "tensorflow-cpu>=2.3",
"tensorflow": "tensorflow>=2.3",
"tf2onnx": "tf2onnx",
"timeout-decorator": "timeout-decorator",
"timm": "timm",
"tokenizers": "tokenizers>=0.10.1,<0.11",
src/transformers/onnx/convert.py: 81 additions & 38 deletions

@@ -21,7 +21,7 @@
 from packaging.version import Version, parse

 from .. import PreTrainedModel, PreTrainedTokenizer, TensorType, TFPreTrainedModel, is_torch_available
-from ..file_utils import is_torch_onnx_dict_inputs_support_available
+from ..file_utils import is_tf_available, is_torch_onnx_dict_inputs_support_available
 from ..utils import logging
 from .config import OnnxConfig

@@ -63,10 +63,14 @@ def check_onnxruntime_requirements(minimum_version: Version):


 def export(
-    tokenizer: PreTrainedTokenizer, model: PreTrainedModel, config: OnnxConfig, opset: int, output: Path
+    tokenizer: PreTrainedTokenizer,
+    model: Union[PreTrainedModel, TFPreTrainedModel],
+    config: OnnxConfig,
+    opset: int,
+    output: Path,
 ) -> Tuple[List[str], List[str]]:
     """
-    Export a PyTorch backed pipeline to ONNX Intermediate Representation (IR
+    Export a PyTorch/TensorFlow backed pipeline to ONNX Intermediate Representation (IR)

     Args:
         tokenizer:
@@ -78,21 +82,68 @@
     Returns:
     """
-    if not is_torch_available():
-        raise ImportError("Cannot convert because PyTorch is not installed. Please install torch first.")
-    from ..file_utils import torch_version
-
-    import torch
-    from torch.onnx import export
+    if not (is_torch_available() or is_tf_available()):
+        raise ImportError(
+            "Cannot convert because neither PyTorch nor TensorFlow is installed. Please install torch or tensorflow first."
+        )
+
+    from ..file_utils import torch_version
+    if is_torch_available():
+        if not is_torch_onnx_dict_inputs_support_available():
+            raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
+
+    if issubclass(type(model), PreTrainedModel):
+        import torch
+        from torch.onnx import export
+
+        logger.info(f"Using framework PyTorch: {torch.__version__}")
+        with torch.no_grad():
+            model.config.return_dict = True
+            model.eval()
+
+            # Check if we need to override certain configuration item
+            if config.values_override is not None:
+                logger.info(f"Overriding {len(config.values_override)} configuration item(s)")
+                for override_config_key, override_config_value in config.values_override.items():
+                    logger.info(f"\t- {override_config_key} -> {override_config_value}")
+                    setattr(model.config, override_config_key, override_config_value)
+
+            # Ensure inputs match
+            # TODO: Check when exporting QA we provide "is_pair=True"
+            model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
+            inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
+            onnx_outputs = list(config.outputs.keys())
+
+            if not inputs_match:
+                raise ValueError("Model and config inputs don't match")
+
+            config.patch_ops()
+
+            # export can work with named args, but the dict containing named args has to be the last element of the args tuple
+            export(
+                model,
+                (model_inputs,),
+                f=output.as_posix(),
+                input_names=list(config.inputs.keys()),
+                output_names=onnx_outputs,
+                dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
+                do_constant_folding=True,
+                use_external_data_format=config.use_external_data_format(model.num_parameters()),
+                enable_onnx_checker=True,
+                opset_version=opset,
+            )
+
-    if not is_torch_onnx_dict_inputs_support_available():
-        raise AssertionError(f"Unsupported PyTorch version, minimum required is 1.8.0, got: {torch_version}")
+            config.restore_ops()
+
+        return matched_inputs, onnx_outputs
+    else:
+        import tensorflow as tf
+
+        import onnx
+        import tf2onnx
+
-    logger.info(f"Using framework PyTorch: {torch.__version__}")
-    with torch.no_grad():
         model.config.return_dict = True
-        model.eval()

         # Check if we need to override certain configuration item
         if config.values_override is not None:
@@ -102,33 +153,16 @@
                 setattr(model.config, override_config_key, override_config_value)

         # Ensure inputs match
-        # TODO: Check when exporting QA we provide "is_pair=True"
-        model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
+        model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)
         inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
         onnx_outputs = list(config.outputs.keys())

         if not inputs_match:
             raise ValueError("Model and config inputs don't match")

-        config.patch_ops()
-
-        # export can works with named args but the dict containing named args as to be last element of the args tuple
-        export(
-            model,
-            (model_inputs,),
-            f=output.as_posix(),
-            input_names=list(config.inputs.keys()),
-            output_names=onnx_outputs,
-            dynamic_axes={name: axes for name, axes in chain(config.inputs.items(), config.outputs.items())},
-            do_constant_folding=True,
-            use_external_data_format=config.use_external_data_format(model.num_parameters()),
-            enable_onnx_checker=True,
-            opset_version=opset,
-        )

+        input_signature = [tf.TensorSpec.from_tensor(tensor, name=key) for key, tensor in model_inputs.items()]
+        onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset)
+        onnx.save(onnx_model, output.as_posix())
-        config.restore_ops()

-        return matched_inputs, onnx_outputs
+        return matched_inputs, onnx_outputs


 def validate_model_outputs(
@@ -145,7 +179,10 @@

     # TODO: generate inputs with a different batch_size and seq_len that was used for conversion to properly test
     # dynamic input shapes.
-    reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
+    if issubclass(type(reference_model), PreTrainedModel):
+        reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
+    else:
+        reference_model_inputs = config.generate_dummy_inputs(tokenizer, framework=TensorType.TENSORFLOW)

     # Create ONNX Runtime session
     options = SessionOptions()
@@ -195,7 +232,10 @@

     # Check the shape and values match
     for name, ort_value in zip(onnx_named_outputs, onnx_outputs):
-        ref_value = ref_outputs_dict[name].detach().numpy()
+        if issubclass(type(reference_model), PreTrainedModel):
+            ref_value = ref_outputs_dict[name].detach().numpy()
+        else:
+            ref_value = ref_outputs_dict[name].numpy()
         logger.info(f'\t- Validating ONNX Model output "{name}":')

         # Shape
@@ -228,7 +268,10 @@ def ensure_model_and_config_inputs_match(
     :param config_inputs:
     :return:
     """
-    forward_parameters = signature(model.forward).parameters
+    if issubclass(type(model), PreTrainedModel):
+        forward_parameters = signature(model.forward).parameters
+    else:
+        forward_parameters = signature(model.call).parameters
     model_inputs_set = set(model_inputs)

     # We are fine if config_inputs has more keys than model_inputs
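For reference, the core of the new TensorFlow branch above is the tf2onnx call. A self-contained sketch of the same technique (the keras_to_onnx helper name and its arguments are hypothetical, not part of this commit):

from pathlib import Path
from typing import Dict

import onnx
import tensorflow as tf
import tf2onnx


def keras_to_onnx(model: tf.keras.Model, dummy_inputs: Dict[str, tf.Tensor], opset: int, output: Path) -> None:
    # Build one TensorSpec per dummy tensor so tf2onnx traces the model
    # with graph inputs named after the tokenizer's output keys
    input_signature = [tf.TensorSpec.from_tensor(tensor, name=name) for name, tensor in dummy_inputs.items()]
    # from_keras returns (ONNX ModelProto, external tensor storage); only the proto is needed here
    onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature, opset=opset)
    onnx.save(onnx_model, output.as_posix())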
tests/test_onnx_v2.py: 96 additions & 1 deletion

@@ -13,7 +13,16 @@
     LayoutLMConfig,
     MBartConfig,
     RobertaConfig,
+    TFAlbertModel,
+    TFBartModel,
+    TFBertModel,
+    TFDistilBertModel,
+    TFGPT2Model,
+    TFMBartModel,
+    TFRobertaModel,
+    TFXLMRobertaModel,
     XLMRobertaConfig,
+    is_tf_available,
     is_torch_available,
 )
 from transformers.models.albert import AlbertOnnxConfig
@@ -39,7 +48,7 @@
 )
 from transformers.onnx.config import DEFAULT_ONNX_OPSET, OnnxConfigWithPast
 from transformers.onnx.utils import compute_effective_axis_dimension, compute_serialized_parameters_size
-from transformers.testing_utils import require_onnx, require_torch, slow
+from transformers.testing_utils import require_onnx, require_tf, require_torch, slow


 @require_onnx
@@ -223,6 +232,42 @@ def test_values_override(self):
     }


+if is_tf_available():
+    from transformers import (  # T5Model,
+        AlbertModel,
+        BartModel,
+        BertModel,
+        DistilBertModel,
+        GPT2Model,
+        GPTNeoModel,
+        LayoutLMModel,
+        MBartModel,
+        RobertaModel,
+        XLMRobertaModel,
+    )
+
+    TENSORFLOW_EXPORT_DEFAULT_MODELS = {
+        ("ALBERT", "hf-internal-testing/tiny-albert", TFAlbertModel, AlbertConfig, AlbertOnnxConfig),
+        ("BART", "facebook/bart-base", TFBartModel, BartConfig, BartOnnxConfig),
+        ("BERT", "bert-base-cased", TFBertModel, BertConfig, BertOnnxConfig),
+        ("DistilBERT", "distilbert-base-cased", TFDistilBertModel, DistilBertConfig, DistilBertOnnxConfig),
+        ("GPT2", "gpt2", TFGPT2Model, GPT2Config, GPT2OnnxConfig),
+        # ("GPT-Neo", "EleutherAI/gpt-neo-125M", TFGPTNeoModel, GPTNeoConfig, GPTNeoOnnxConfig),
+        # ("LongFormer", "longformer-base-4096", LongformerModel, LongformerConfig, LongformerOnnxConfig),
+        ("Roberta", "roberta-base", TFRobertaModel, RobertaConfig, RobertaOnnxConfig),
+        ("XLM-Roberta", "roberta-base", TFXLMRobertaModel, XLMRobertaConfig, XLMRobertaOnnxConfig),
+        # ("LayoutLM", "microsoft/layoutlm-base-uncased", TFLayoutLMModel, LayoutLMConfig, LayoutLMOnnxConfig),
+        ("MBart", "sshleifer/tiny-mbart", TFMBartModel, MBartConfig, MBartOnnxConfig),
+        # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig),
+    }
+
+    TENSORFLOW_EXPORT_WITH_PAST_MODELS = {
+        # ("BART", "facebook/bart-base", BartModel, BartConfig, BartOnnxConfig),
+        # ("GPT2", "gpt2", GPT2Model, GPT2Config, GPT2OnnxConfig),
+        # ("T5", "t5-small", T5Model, T5Config, T5OnnxConfig)
+    }
+
+
 class OnnxExportTestCaseV2(TestCase):
     """
     Integration tests ensuring supported models are correctly exported
@@ -251,6 +296,29 @@ def test_pytorch_export_default(self):
                     except ValueError as ve:
                         self.fail(f"{name} -> {ve}")

+    @slow
+    @require_tf
+    def test_tensorflow_export_default(self):
+        from transformers.onnx import export
+
+        for name, model, model_class, config_class, onnx_config_class in TENSORFLOW_EXPORT_DEFAULT_MODELS:
+            with self.subTest(name):
+                self.assertTrue(hasattr(onnx_config_class, "from_model_config"))
+
+                tokenizer = AutoTokenizer.from_pretrained(model)
+                model = model_class(config_class.from_pretrained(model))
+                onnx_config = onnx_config_class.from_model_config(model.config)
+
+                with NamedTemporaryFile("w") as output:
+                    onnx_inputs, onnx_outputs = export(
+                        tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, Path(output.name)
+                    )
+
+                    try:
+                        validate_model_outputs(onnx_config, tokenizer, model, Path(output.name), onnx_outputs, 1e-5)
+                    except ValueError as ve:
+                        self.fail(f"{name} -> {ve}")
+
     @slow
     @require_torch
     def test_pytorch_export_with_past(self):
@@ -277,3 +345,30 @@
                         validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, 1e-5)
                     except ValueError as ve:
                         self.fail(f"{name} -> {ve}")
+
+    @slow
+    @require_tf
+    def test_tensorflow_export_with_past(self):
+        from transformers.onnx import export
+
+        for name, model, model_class, config_class, onnx_config_class in TENSORFLOW_EXPORT_WITH_PAST_MODELS:
+            with self.subTest(name):
+                self.assertTrue(hasattr(onnx_config_class, "with_past"), "OnnxConfigWithPast should have with_past()")
+
+                tokenizer = AutoTokenizer.from_pretrained(model)
+                model = model_class(config_class())
+                onnx_config = onnx_config_class.with_past(model.config)
+
+                self.assertTrue(hasattr(onnx_config, "use_past"), "OnnxConfigWithPast should have use_past attribute.")
+                self.assertTrue(
+                    onnx_config.use_past, "OnnxConfigWithPast.use_past should be True if called with with_past()"
+                )
+
+                with NamedTemporaryFile("w") as output:
+                    output = Path(output.name)
+                    onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, DEFAULT_ONNX_OPSET, output)
+
+                    try:
+                        validate_model_outputs(onnx_config, tokenizer, model, output, onnx_outputs, 1e-5)
+                    except ValueError as ve:
+                        self.fail(f"{name} -> {ve}")
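
As a complement to validate_model_outputs, an exported file can be smoke-tested directly with ONNX Runtime. A hedged sketch (the file name, checkpoint, and dtype handling are illustrative assumptions, continuing the export example near the top of this page):

import numpy as np
from onnxruntime import InferenceSession
from transformers import AutoTokenizer

session = InferenceSession("bert.onnx")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
encoded = tokenizer("Exported from TensorFlow", return_tensors="np")

# Feed only the inputs the graph declares, cast to the integer width it expects
# (tf2onnx-traced graphs typically declare int32, torch.onnx exports int64)
feeds = {}
for graph_input in session.get_inputs():
    dtype = np.int32 if "int32" in graph_input.type else np.int64
    feeds[graph_input.name] = encoded[graph_input.name].astype(dtype)

outputs = session.run(None, feeds)  # None fetches every declared output
print([output.shape for output in outputs])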
