From b06abf104f28060f7ec5a5a6265120b3dd069293 Mon Sep 17 00:00:00 2001
From: Jay Zhang <36183870+fatcat-z@users.noreply.github.com>
Date: Mon, 6 Dec 2021 21:01:51 +0800
Subject: [PATCH] Update the example of exporting Bart + BeamSearch to ONNX
 module to resolve comments. (#14310)

* Update code to resolve comments left in previous PR.

* Add README.md file for this example.

* Update examples/onnx/pytorch/translation/README.md

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update examples/onnx/pytorch/translation/README.md

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update examples/onnx/pytorch/translation/README.md

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update README.md file to resolve comments.

* Add a section name.

* Update examples/onnx/pytorch/translation/README.md

Co-authored-by: Gary Miguel

* Add more comments for _convert_past_list_to_tuple().

* Change the default file name to a consistent one.

* Fix a format issue.

* Update examples/onnx/pytorch/translation/README.md

Co-authored-by: Gary Miguel

* Update examples/onnx/pytorch/translation/run_onnx_exporter.py

Co-authored-by: Gary Miguel

* Update examples/onnx/pytorch/translation/README.md

Co-authored-by: lewtun

* Change the folder to summarization and address some other coments.

* Update the torch version.

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Gary Miguel
Co-authored-by: lewtun
---
 examples/onnx/pytorch/summarization/README.md |  43 ++++++
 .../bart_onnx/generation_onnx.py              | 137 +++++-------------
 .../bart_onnx/reduce_onnx_size.py             |  33 +++--
 .../pytorch/summarization/requirements.txt    |   1 +
 .../run_onnx_exporter.py                      | 123 +++++---------
 .../onnx/pytorch/translation/requirements.txt |   1 -
 6 files changed, 159 insertions(+), 179 deletions(-)
 create mode 100644 examples/onnx/pytorch/summarization/README.md
 rename examples/onnx/pytorch/{translation => summarization}/bart_onnx/generation_onnx.py (90%)
 rename examples/onnx/pytorch/{translation => summarization}/bart_onnx/reduce_onnx_size.py (71%)
 create mode 100644 examples/onnx/pytorch/summarization/requirements.txt
 rename examples/onnx/pytorch/{translation => summarization}/run_onnx_exporter.py (60%)
 delete mode 100644 examples/onnx/pytorch/translation/requirements.txt

diff --git a/examples/onnx/pytorch/summarization/README.md b/examples/onnx/pytorch/summarization/README.md
new file mode 100644
index 00000000000000..6fd1ffe70aeb7d
--- /dev/null
+++ b/examples/onnx/pytorch/summarization/README.md
@@ -0,0 +1,43 @@
+
+
+# Bart + Beam Search to ONNX
+
+
+
+This folder contains an example of exporting Bart + Beam Search generation (`BartForConditionalGeneration`) to ONNX.
+
+Beam Search contains a for-loop workflow, so we need to make it TorchScript-compatible for exporting to ONNX. This example shows how to make a Bart model TorchScript-compatible by wrapping it into a new model. In addition, some changes were made to the `beam_search()` function to make it TorchScript-compatible.
+
+
+## How to run the example
+
+To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+
+```bash
+git clone https://github.com/huggingface/transformers
+cd transformers
+pip install .
+```
+Then cd into this example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+Now you can run the example command below to get the example ONNX file:
+
+```bash
+python run_onnx_exporter.py --model_name_or_path facebook/bart-base
+```
diff --git a/examples/onnx/pytorch/translation/bart_onnx/generation_onnx.py b/examples/onnx/pytorch/summarization/bart_onnx/generation_onnx.py
similarity index 90%
rename from examples/onnx/pytorch/translation/bart_onnx/generation_onnx.py
rename to examples/onnx/pytorch/summarization/bart_onnx/generation_onnx.py
index ac2f6a9434332a..59d9c6c092b47a 100644
--- a/examples/onnx/pytorch/translation/bart_onnx/generation_onnx.py
+++ b/examples/onnx/pytorch/summarization/bart_onnx/generation_onnx.py
@@ -1,4 +1,5 @@
 import copy
+import itertools
 from typing import List, Optional, Tuple
 
 import torch
@@ -8,23 +9,23 @@
 from transformers.generation_utils import GenerationMixin
 
 
-def flatten_list(past):
-    values = []
-    if past is not None:
-        for i, p in enumerate(past):
-            for j, q in enumerate(p):
-                values.append(q)
-
-    return values
-
+def _convert_past_list_to_tuple(past_key_values):
+    """
+    In the Bart model, the type of past_key_values is tuple(tuple(torch.FloatTensor)), which is not
+    TorchScript-compatible. To support this, we have to convert it during the export process.
+    This function will convert past values from a list to tuple(tuple(torch.FloatTensor)) for
+    the inner decoder.
 
-def list_to_tuple(past):
+    According to the definition of past_key_values, each inner tuple(torch.FloatTensor) has 4 tensors,
+    so we convert every 4 elements in the list into a tuple(torch.FloatTensor).
+    """
+    count_of_each_inner_tuple = 4
     results = ()
     temp_result = ()
-    count_n = len(past) // 4
+    count_n = len(past_key_values) // count_of_each_inner_tuple
     for idx in range(count_n):
-        real_idx = idx * 4
-        temp_result = tuple(past[real_idx : real_idx + 4])
+        real_idx = idx * count_of_each_inner_tuple
+        temp_result = tuple(past_key_values[real_idx : real_idx + count_of_each_inner_tuple])
         results += ((temp_result),)
 
     return results
@@ -51,7 +52,7 @@ def __init__(self, decoder):
     def forward(self, input_ids, encoder_state, attention_mask, past=None):
         all_results = None
         if past is not None:
-            all_results = list_to_tuple(past)
+            all_results = _convert_past_list_to_tuple(past)
             input_ids = input_ids[:, -1:]
 
         last_hidden_state, past_key_values = self.decoder(
@@ -68,28 +69,33 @@ def forward(self, input_ids, encoder_state, attention_mask, past=None):
         return last_hidden_state, past_values
 
 
-def create_traced_encoder(encoder, input_ids, attention_mask):
+def _create_traced_encoder(encoder, input_ids, attention_mask):
     encoder_c = copy.deepcopy(encoder)
     encoder_for_onnx = EncoderForONNX(encoder_c)
 
-    # return torch.jit.trace(encoder, (input_ids, attention_mask))
     return torch.jit.trace(encoder_for_onnx, (input_ids, attention_mask))
 
 
-def create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
+def _create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
     decoder_c = copy.deepcopy(decoder)
     decoder_for_onnx = DecoderForONNX(decoder_c)
-    past_values = flatten_list(past)
+    past_values = list(itertools.chain.from_iterable(past or ()))
 
     # Do this twice so we got 2 different decoders for further work.
-    if past_values is None or len(past_values) == 0:
-        return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
-    else:
+    if past_values:
         return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask, past_values))
+    else:
+        return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
 
 
 class BartConfigTS(BartConfig, torch.nn.Module):
-    def init_module(self):
+    """
+    BartConfigTS is a TorchScript-compatible transformers.models.bart.configuration_bart.BartConfig.
+    TorchScript only supports sub-classes of torch.nn.Module.
+    """
+
+    def __init__(self, config):
+        BartConfig.__init__(self, config)
         torch.nn.Module.__init__(self)
@@ -127,7 +133,6 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
     def __init__(self, model):
         super().__init__()
         self.config = BartConfigTS(model.config)
-        self.config.init_module()
         self.config.force_bos_token_to_be_generated = False
         self._trace_modules(model)
         self.logits_processor = MinLengthLogitsProcessorTS(self.config.min_length, self.config.eos_token_id)
@@ -136,7 +141,6 @@ def __init__(self, model):
         self.decoder_layers = model.config.decoder_layers
 
     def _trace_modules(self, model):
-        # Be aware of the last one 2 should be kept.
         input_ids = torch.tensor(
             [
                 [
@@ -200,89 +204,25 @@ def _trace_modules(self, model):
                     57,
                     8629,
                     5,
-                    2,
+                    model.config.eos_token_id,
                 ]
             ],
             device=model.device,
             dtype=torch.long,
         )
         attention_mask = torch.tensor(
-            [
-                [
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                    True,
-                ]
-            ],
+            [[True] * input_ids.shape[-1]],
             device=model.device,
             dtype=torch.bool,
         )
-        self.encoder = create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
+        self.encoder = _create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
         encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask, return_dict=True)
         decoder = model.model.decoder
         decoder_outputs = decoder(input_ids, attention_mask, encoder_outputs["last_hidden_state"], None, None, None)
-        self.decoder_no_past = create_traced_decoder(
+        self.decoder_no_past = _create_traced_decoder(
             model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask
         )
-        self.decoder_with_past = create_traced_decoder(
+        self.decoder_with_past = _create_traced_decoder(
             model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask, decoder_outputs[1]
         )
@@ -414,8 +354,8 @@ def __init__(self):
         self._beam_hyps_count = torch.zeros(self.batch_size, dtype=torch.long)
         self._beam_hyps_worst_scores = torch.zeros(self.batch_size) + 1e9
         self._beam_hyps_max_length: int = self.max_length - 1
-        self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatible
-        self._beam_scores: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatible
+        self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatibility
+        self._beam_scores: List[torch.Tensor] = [torch.zeros(2)]  # placeholder for TorchScript compatibility
 
     def is_done(self) -> torch.Tensor:
         return self._done.all()
@@ -474,11 +414,11 @@ def hypo_add(self, hyp: torch.Tensor, sum_logprobs: float, hypo_idx: int):
         score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
         hyps_count = self.hypo_len(hypo_idx)
         if hyps_count < self.num_beams or score > self._beam_hyps_worst_scores[hypo_idx]:
-            # NOTE: work around difference of torch.sum(empty_tensor) = 0, while error in onnx.
+            # NOTE: work around difference of torch.sum(empty_tensor) == 0, while error in onnx.
+            # Bug: https://msdata.visualstudio.com/Vienna/_workitems/edit/1486599
             beam_idx = (
                 torch.sum(self._beam_hyps_count[:hypo_idx]) if hypo_idx != 0 else torch.tensor(0, dtype=torch.long)
             )
-            # beam_idx = torch.sum(_beam_hyps_count[:hypo_idx])
             self._beam_scores.insert(beam_idx, torch.tensor([score]))
             self._beam_hyps.insert(beam_idx, hyp)
             if hyps_count + 1 > self.num_beams:
@@ -605,7 +545,7 @@ def finalize(
             self.hypo_add(final_tokens, final_score, batch_idx)
 
         # select the best hypotheses
-        # NOTE: new is not scriptable
+        # NOTE: torch.Tensor.new_zeros() is not scriptable
         sent_lengths = torch.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=torch.long)
         best = []
         best_scores = torch.zeros(
@@ -782,7 +722,6 @@ def forward(self, input_ids, attention_mask, num_beams, max_length, decoder_star
             bos_token_id=bos_token_id,
         )
 
-        # from generation_utils.py
         batch_size = input_ids.shape[0]
 
         length_penalty = self.config.length_penalty
diff --git a/examples/onnx/pytorch/translation/bart_onnx/reduce_onnx_size.py b/examples/onnx/pytorch/summarization/bart_onnx/reduce_onnx_size.py
similarity index 71%
rename from examples/onnx/pytorch/translation/bart_onnx/reduce_onnx_size.py
rename to examples/onnx/pytorch/summarization/bart_onnx/reduce_onnx_size.py
index d16c1e4c41c5f3..63fae44ffac6bc 100644
--- a/examples/onnx/pytorch/translation/bart_onnx/reduce_onnx_size.py
+++ b/examples/onnx/pytorch/summarization/bart_onnx/reduce_onnx_size.py
@@ -1,3 +1,7 @@
+"""
+Code to remove duplicate initializers to reduce ONNX model size.
+"""
+
 import os
 
 import numpy
@@ -5,7 +9,7 @@
 import onnx
 
 
-def is_equal_tensor_proto(a, b):
+def _is_equal_tensor_proto(a, b):
     name_a = a.name
     name_b = b.name
@@ -20,25 +24,25 @@ def is_equal_tensor_proto(a, b):
     return res
 
 
-def node_replace_input_with(node_proto, name, new_name):
+def _node_replace_input_with(node_proto, name, new_name):
     for i, input_name in enumerate(node_proto.input):
         if input_name == name:
             node_proto.input.insert(i, new_name)
             node_proto.input.pop(i + 1)
     if node_proto.op_type == "If":
-        graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
-        graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
+        _graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
+        _graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
     if node_proto.op_type == "Loop":
-        graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
+        _graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
 
 
-def graph_replace_input_with(graph_proto, name, new_name):
+def _graph_replace_input_with(graph_proto, name, new_name):
     for n in graph_proto.node:
-        node_replace_input_with(n, name, new_name)
+        _node_replace_input_with(n, name, new_name)
 
 
-def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
+def _remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
     inits_with_data = [i for i in model.graph.initializer]
     inits = [i for i in model_without_ext.graph.initializer]
     for i, ref_i in ind_to_replace:
@@ -52,10 +56,15 @@ def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace)
         model_without_ext.graph.initializer.remove(inits[i])
 
         # for n in model.graph.node:
-        graph_replace_input_with(model_without_ext.graph, name_i, name_ref)
+        _graph_replace_input_with(model_without_ext.graph, name_i, name_ref)
 
 
 def remove_dup_initializers(onnx_file_path):
+    """
+    Removes duplicate initializers from the model to reduce its size.
+    Writes a new file in the same directory as onnx_file_path and returns the path to that file.
+    """
+
     model_file_folder = os.path.dirname(onnx_file_path)
     model_file_name = os.path.basename(onnx_file_path)
 
@@ -76,7 +85,7 @@ def remove_dup_initializers(onnx_file_path):
             for j in range(i + 1, len(inits)):
                 if j in dup_set:
                     continue
 
-                if is_equal_tensor_proto(inits[i], inits[j]):
+                if _is_equal_tensor_proto(inits[i], inits[j]):
                     dup_set.add(i)
                     dup_set.add(j)
 
@@ -103,8 +112,8 @@ def remove_dup_initializers(onnx_file_path):
 
     print("total reduced size: ", total_reduced_size / 1024 / 1024 / 1024, "GB")
 
-    ind_to_replace = sorted(ind_to_replace, key=lambda x: x[0])
-    remove_dup_initializers_from_model(model, model, ind_to_replace)
+    ind_to_replace = sorted(ind_to_replace)
+    _remove_dup_initializers_from_model(model, model, ind_to_replace)
 
     optimized_model_file_name = "optimized_" + model_file_name
     new_model = os.path.join(model_file_folder, optimized_model_file_name)
diff --git a/examples/onnx/pytorch/summarization/requirements.txt b/examples/onnx/pytorch/summarization/requirements.txt
new file mode 100644
index 00000000000000..215356506121ca
--- /dev/null
+++ b/examples/onnx/pytorch/summarization/requirements.txt
@@ -0,0 +1 @@
+torch >= 1.10
\ No newline at end of file
diff --git a/examples/onnx/pytorch/translation/run_onnx_exporter.py b/examples/onnx/pytorch/summarization/run_onnx_exporter.py
similarity index 60%
rename from examples/onnx/pytorch/translation/run_onnx_exporter.py
rename to examples/onnx/pytorch/summarization/run_onnx_exporter.py
index 1355903df65efe..2a62ca9f704dbb 100644
--- a/examples/onnx/pytorch/translation/run_onnx_exporter.py
+++ b/examples/onnx/pytorch/summarization/run_onnx_exporter.py
@@ -20,7 +20,6 @@
 import logging
 import os
 import sys
-from datetime import datetime
 
 import numpy as np
 import torch
@@ -46,7 +45,7 @@ def parse_args():
-    parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
+    parser = argparse.ArgumentParser(description="Export Bart model + Beam Search to ONNX graph.")
     parser.add_argument(
         "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
     )
@@ -104,13 +103,12 @@ def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_l
     model.eval()
 
     ort_sess = None
-    onnx_bart = torch.jit.script(BARTBeamSearchGenerator(model))
+    bart_script_model = torch.jit.script(BARTBeamSearchGenerator(model))
 
     with torch.no_grad():
         ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
        inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt").to(model.device)
 
-        # Test export here.
         summary_ids = model.generate(
             inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
@@ -120,53 +118,54 @@ def export_and_validate_model(model, tokenizer, onnx_file_path, num_beams, max_l
             decoder_start_token_id=model.config.decoder_start_token_id,
         )
 
-        if not ort_sess:
-            torch.onnx.export(
-                onnx_bart,
-                (
-                    inputs["input_ids"],
-                    inputs["attention_mask"],
-                    num_beams,
-                    max_length,
-                    model.config.decoder_start_token_id,
-                ),
-                onnx_file_path,
-                opset_version=14,
-                input_names=["input_ids", "attention_mask", "num_beams", "max_length", "decoder_start_token_id"],
-                output_names=["output_ids"],
-                dynamic_axes={
-                    "input_ids": {0: "batch", 1: "seq"},
-                    "output_ids": {0: "batch", 1: "seq_out"},
-                },
-                verbose=False,
-                strip_doc_string=False,
-                example_outputs=summary_ids,
-            )
-
-            new_onnx_file_path = remove_dup_initializers(os.path.abspath(onnx_file_path))
-
-            ort_sess = onnxruntime.InferenceSession(new_onnx_file_path)
-            ort_out = ort_sess.run(
-                None,
-                {
-                    "input_ids": inputs["input_ids"].cpu().numpy(),
-                    "attention_mask": inputs["attention_mask"].cpu().numpy(),
-                    "num_beams": np.array(num_beams),
-                    "max_length": np.array(max_length),
-                    "decoder_start_token_id": np.array(model.config.decoder_start_token_id),
-                },
-            )
-
-            np.testing.assert_allclose(summary_ids.cpu().numpy(), ort_out[0], rtol=1e-3, atol=1e-3)
-
-            print("========= Pass - Results are matched! =========")
+        torch.onnx.export(
+            bart_script_model,
+            (
+                inputs["input_ids"],
+                inputs["attention_mask"],
+                num_beams,
+                max_length,
+                model.config.decoder_start_token_id,
+            ),
+            onnx_file_path,
+            opset_version=14,
+            input_names=["input_ids", "attention_mask", "num_beams", "max_length", "decoder_start_token_id"],
+            output_names=["output_ids"],
+            dynamic_axes={
+                "input_ids": {0: "batch", 1: "seq"},
+                "output_ids": {0: "batch", 1: "seq_out"},
+            },
+            example_outputs=summary_ids,
+        )
+
+        logger.info("Model exported to {}".format(onnx_file_path))
+
+        new_onnx_file_path = remove_dup_initializers(os.path.abspath(onnx_file_path))
+
+        logger.info("Deduplicated and optimized model written to {}".format(new_onnx_file_path))
+
+        ort_sess = onnxruntime.InferenceSession(new_onnx_file_path)
+        ort_out = ort_sess.run(
+            None,
+            {
+                "input_ids": inputs["input_ids"].cpu().numpy(),
+                "attention_mask": inputs["attention_mask"].cpu().numpy(),
+                "num_beams": np.array(num_beams),
+                "max_length": np.array(max_length),
+                "decoder_start_token_id": np.array(model.config.decoder_start_token_id),
+            },
+        )
+
+        np.testing.assert_allclose(summary_ids.cpu().numpy(), ort_out[0], rtol=1e-3, atol=1e-3)
+
+        logger.info("Model outputs from torch and ONNX Runtime are similar.")
+        logger.info("Success.")
 
 
 def main():
     args = parse_args()
-    local_device = None
-    local_max_length = 5
-    local_num_beams = 4
+    max_length = 5
+    num_beams = 4
 
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
@@ -175,41 +174,31 @@ def main():
         level=logging.INFO,
     )
 
-    logger.setLevel(logging.ERROR)
+    logger.setLevel(logging.INFO)
     transformers.utils.logging.set_verbosity_error()
 
-    if args.model_name_or_path:
-        model, tokenizer = load_model_tokenizer(args.model_name_or_path, local_device)
-    else:
-        raise ValueError("Make sure that model name has been passed")
+    device = torch.device(args.device)
+
+    model, tokenizer = load_model_tokenizer(args.model_name_or_path, device)
 
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
 
-    if args.device:
-        if args.device == "cuda" and not torch.cuda.is_available():
-            raise ValueError("CUDA is not available in this server.")
-
-        local_device = torch.device(args.device)
-    else:
-        local_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
-    model.to(local_device)
+    model.to(device)
 
     if args.max_length:
-        local_max_length = args.max_length
+        max_length = args.max_length
 
     if args.num_beams:
-        local_num_beams = args.num_beams
+        num_beams = args.num_beams
 
     if args.output_file_path:
        output_name = args.output_file_path
     else:
-        output_name = "onnx_model_{}.onnx".format(datetime.now().utcnow().microsecond)
-
-    export_and_validate_model(model, tokenizer, output_name, local_num_beams, local_max_length)
+        output_name = "BART.onnx"
 
-    logger.info("***** Running export *****")
+    logger.info("Exporting model to ONNX")
+    export_and_validate_model(model, tokenizer, output_name, num_beams, max_length)
 
 
 if __name__ == "__main__":
diff --git a/examples/onnx/pytorch/translation/requirements.txt b/examples/onnx/pytorch/translation/requirements.txt
deleted file mode 100644
index 5714ddd64862c9..00000000000000
--- a/examples/onnx/pytorch/translation/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-torch >= 1.8
\ No newline at end of file
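
Usage note (not part of the patch): once `run_onnx_exporter.py` has been run with its defaults, the deduplicated graph is written next to `BART.onnx` as `optimized_BART.onnx`, and it exposes the input/output names declared in the `torch.onnx.export` call above. The snippet below is a minimal sketch of loading that file back for inference; the file name, beam settings, and the decoder start token value (2 for `facebook/bart-base`) are assumptions taken from the script defaults, not part of the commit itself.

```python
# Minimal sketch (not part of the patch): run the exported Bart + Beam Search ONNX graph.
# Assumes run_onnx_exporter.py was run with its defaults, so the optimized file is
# "optimized_BART.onnx" and the graph uses the input/output names from torch.onnx.export.
import numpy as np
import onnxruntime
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
inputs = tokenizer(["My friends are cool but they eat too many carbs."], return_tensors="np")

sess = onnxruntime.InferenceSession("optimized_BART.onnx")
output_ids = sess.run(
    ["output_ids"],
    {
        "input_ids": inputs["input_ids"].astype(np.int64),
        "attention_mask": inputs["attention_mask"].astype(np.int64),
        "num_beams": np.array(4),          # same default as run_onnx_exporter.py
        "max_length": np.array(5),         # same default as run_onnx_exporter.py
        "decoder_start_token_id": np.array(2),  # BartConfig.decoder_start_token_id for facebook/bart-base
    },
)[0]

# Decode the beam-search output back to text.
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```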