Skip to content

Commit

Permalink
Update the example of exporting Bart + BeamSearch to ONNX module to r…
Browse files Browse the repository at this point in the history
…esolve comments. (huggingface#14310)

* Update code to resolve comments left in previous PR.

* Add README.md file for this example.

* Update examples/onnx/pytorch/translation/README.md

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update examples/onnx/pytorch/translation/README.md

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update examples/onnx/pytorch/translation/README.md

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update README.md file to resolve comments.

* Add a section name.

* Update examples/onnx/pytorch/translation/README.md

Co-authored-by: Gary Miguel <garymm@garymm.org>

* Add more comments for _convert_past_list_to_tuple().

* Change the default file name to a consistent one.

* Fix a format issue.

* Update examples/onnx/pytorch/translation/README.md

Co-authored-by: Gary Miguel <garymm@garymm.org>

* Update examples/onnx/pytorch/translation/run_onnx_exporter.py

Co-authored-by: Gary Miguel <garymm@garymm.org>

* Update examples/onnx/pytorch/translation/README.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Change the folder to summarization and address some other coments.

* Update the torch version.

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Gary Miguel <garymm@garymm.org>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
  • Loading branch information
4 people authored and Alberto Bégué committed Jan 27, 2022
1 parent d7b052b commit b06abf1
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 179 deletions.
43 changes: 43 additions & 0 deletions examples/onnx/pytorch/summarization/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
<!---
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Bart + Beam Search to ONNX



This folder contains an example of exporting Bart + Beam Search generation (`BartForConditionalGeneration`) to ONNX.

Beam Search contains a for-loop workflow, so we need to make them TorchScript-compatible for exporting to ONNX. This example shows how to make a Bart model be TorchScript-compatible by wrapping up it into a new model. In addition, some changes were made to the `beam_search()` function to make it TorchScript-compatible.


## How to run the example

To make sure you can successfully run the latest versions of the example scripts, you have to **install the library from source** and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/transformers
cd transformers
pip install .
```
Then cd in this example folder and run
```bash
pip install -r requirements.txt
```

Now you can run the example command below to get the example ONNX file:

```bash
python run_onnx_exporter.py --model_name_or_path facebook/bart-base
```
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import itertools
from typing import List, Optional, Tuple

import torch
Expand All @@ -8,23 +9,23 @@
from transformers.generation_utils import GenerationMixin


def flatten_list(past):
values = []
if past is not None:
for i, p in enumerate(past):
for j, q in enumerate(p):
values.append(q)

return values

def _convert_past_list_to_tuple(past_key_values):
"""
In Bart model, the type of past_key_values is tuple(tuple(torch.FloatTensor)) which is not
TorchScript-compatible. To support this, we have to convert it during the export process.
This function will convert past values from a list to tuple(tuple(torch.FloatTensor)) for
the inner decoder.
def list_to_tuple(past):
According to the definition of past_key_values, each inner tuple(torch.FloatTensor) has 4 tensors,
so we convert every 4 elements in the list as a tuple(torch.FloatTensor).
"""
count_of_each_inner_tuple = 4
results = ()
temp_result = ()
count_n = len(past) // 4
count_n = len(past_key_values) // count_of_each_inner_tuple
for idx in range(count_n):
real_idx = idx * 4
temp_result = tuple(past[real_idx : real_idx + 4])
real_idx = idx * count_of_each_inner_tuple
temp_result = tuple(past_key_values[real_idx : real_idx + count_of_each_inner_tuple])
results += ((temp_result),)

return results
Expand All @@ -51,7 +52,7 @@ def __init__(self, decoder):
def forward(self, input_ids, encoder_state, attention_mask, past=None):
all_results = None
if past is not None:
all_results = list_to_tuple(past)
all_results = _convert_past_list_to_tuple(past)
input_ids = input_ids[:, -1:]

last_hidden_state, past_key_values = self.decoder(
Expand All @@ -68,28 +69,33 @@ def forward(self, input_ids, encoder_state, attention_mask, past=None):
return last_hidden_state, past_values


def create_traced_encoder(encoder, input_ids, attention_mask):
def _create_traced_encoder(encoder, input_ids, attention_mask):
encoder_c = copy.deepcopy(encoder)
encoder_for_onnx = EncoderForONNX(encoder_c)

# return torch.jit.trace(encoder, (input_ids, attention_mask))
return torch.jit.trace(encoder_for_onnx, (input_ids, attention_mask))


def create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
def _create_traced_decoder(decoder, input_ids, encoder_state, attention_mask, past=None):
decoder_c = copy.deepcopy(decoder)
decoder_for_onnx = DecoderForONNX(decoder_c)
past_values = flatten_list(past)
past_values = list(itertools.chain.from_iterable(past or ()))

# Do this twice so we got 2 different decoders for further work.
if past_values is None or len(past_values) == 0:
return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))
else:
if past_values:
return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask, past_values))
else:
return torch.jit.trace(decoder_for_onnx, (input_ids, encoder_state, attention_mask))


class BartConfigTS(BartConfig, torch.nn.Module):
def init_module(self):
"""
BartConfigTS is a TorchScript-compatible transformers.models.bart.configuration_bart.BartConfig.
TorchScript only supports sub-classes of torch.nn.Module.
"""

def __init__(self, config):
BartConfig.__init__(self, config)
torch.nn.Module.__init__(self)


Expand Down Expand Up @@ -127,7 +133,6 @@ class BARTGenerator(torch.nn.Module, GenerationMixin):
def __init__(self, model):
super().__init__()
self.config = BartConfigTS(model.config)
self.config.init_module()
self.config.force_bos_token_to_be_generated = False
self._trace_modules(model)
self.logits_processor = MinLengthLogitsProcessorTS(self.config.min_length, self.config.eos_token_id)
Expand All @@ -136,7 +141,6 @@ def __init__(self, model):
self.decoder_layers = model.config.decoder_layers

def _trace_modules(self, model):
# Be aware of the last one 2 should be kept.
input_ids = torch.tensor(
[
[
Expand Down Expand Up @@ -200,89 +204,25 @@ def _trace_modules(self, model):
57,
8629,
5,
2,
model.config.eos_token_id,
]
],
device=model.device,
dtype=torch.long,
)
attention_mask = torch.tensor(
[
[
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
True,
]
],
[[True] * input_ids.shape[-1]],
device=model.device,
dtype=torch.bool,
)
self.encoder = create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
self.encoder = _create_traced_encoder(model.get_encoder(), input_ids, attention_mask)
encoder_outputs = model.get_encoder()(input_ids, attention_mask=attention_mask, return_dict=True)
decoder = model.model.decoder
decoder_outputs = decoder(input_ids, attention_mask, encoder_outputs["last_hidden_state"], None, None, None)
self.decoder_no_past = create_traced_decoder(
self.decoder_no_past = _create_traced_decoder(
model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask
)
self.decoder_with_past = create_traced_decoder(
self.decoder_with_past = _create_traced_decoder(
model.model.decoder, input_ids, encoder_outputs["last_hidden_state"], attention_mask, decoder_outputs[1]
)

Expand Down Expand Up @@ -414,8 +354,8 @@ def __init__(self):
self._beam_hyps_count = torch.zeros(self.batch_size, dtype=torch.long)
self._beam_hyps_worst_scores = torch.zeros(self.batch_size) + 1e9
self._beam_hyps_max_length: int = self.max_length - 1
self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatible
self._beam_scores: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatible
self._beam_hyps: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatibility
self._beam_scores: List[torch.Tensor] = [torch.zeros(2)] # placeholder for TorchScript compatibility

def is_done(self) -> torch.Tensor:
return self._done.all()
Expand Down Expand Up @@ -474,11 +414,11 @@ def hypo_add(self, hyp: torch.Tensor, sum_logprobs: float, hypo_idx: int):
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
hyps_count = self.hypo_len(hypo_idx)
if hyps_count < self.num_beams or score > self._beam_hyps_worst_scores[hypo_idx]:
# NOTE: work around difference of torch.sum(empty_tensor) = 0, while error in onnx.
# NOTE: work around difference of torch.sum(empty_tensor) == 0, while error in onnx.
# Bug: https://msdata.visualstudio.com/Vienna/_workitems/edit/1486599
beam_idx = (
torch.sum(self._beam_hyps_count[:hypo_idx]) if hypo_idx != 0 else torch.tensor(0, dtype=torch.long)
)
# beam_idx = torch.sum(_beam_hyps_count[:hypo_idx])
self._beam_scores.insert(beam_idx, torch.tensor([score]))
self._beam_hyps.insert(beam_idx, hyp)
if hyps_count + 1 > self.num_beams:
Expand Down Expand Up @@ -605,7 +545,7 @@ def finalize(
self.hypo_add(final_tokens, final_score, batch_idx)

# select the best hypotheses
# NOTE: new is not scriptable
# NOTE: torch.Tensor.new_zeros() is not scriptable
sent_lengths = torch.zeros(batch_size * self.num_beam_hyps_to_keep, dtype=torch.long)
best = []
best_scores = torch.zeros(
Expand Down Expand Up @@ -782,7 +722,6 @@ def forward(self, input_ids, attention_mask, num_beams, max_length, decoder_star
bos_token_id=bos_token_id,
)

# from generation_utils.py
batch_size = input_ids.shape[0]

length_penalty = self.config.length_penalty
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
"""
Code to remove duplicate initializers to reduce ONNX model size.
"""

import os

import numpy

import onnx


def is_equal_tensor_proto(a, b):
def _is_equal_tensor_proto(a, b):
name_a = a.name
name_b = b.name

Expand All @@ -20,25 +24,25 @@ def is_equal_tensor_proto(a, b):
return res


def node_replace_input_with(node_proto, name, new_name):
def _node_replace_input_with(node_proto, name, new_name):
for i, input_name in enumerate(node_proto.input):
if input_name == name:
node_proto.input.insert(i, new_name)
node_proto.input.pop(i + 1)

if node_proto.op_type == "If":
graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
_graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
_graph_replace_input_with(node_proto.attribute[1].g, name, new_name)
if node_proto.op_type == "Loop":
graph_replace_input_with(node_proto.attribute[0].g, name, new_name)
_graph_replace_input_with(node_proto.attribute[0].g, name, new_name)


def graph_replace_input_with(graph_proto, name, new_name):
def _graph_replace_input_with(graph_proto, name, new_name):
for n in graph_proto.node:
node_replace_input_with(n, name, new_name)
_node_replace_input_with(n, name, new_name)


def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
def _remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace):
inits_with_data = [i for i in model.graph.initializer]
inits = [i for i in model_without_ext.graph.initializer]
for i, ref_i in ind_to_replace:
Expand All @@ -52,10 +56,15 @@ def remove_dup_initializers_from_model(model, model_without_ext, ind_to_replace)
model_without_ext.graph.initializer.remove(inits[i])

# for n in model.graph.node:
graph_replace_input_with(model_without_ext.graph, name_i, name_ref)
_graph_replace_input_with(model_without_ext.graph, name_i, name_ref)


def remove_dup_initializers(onnx_file_path):
"""
Removes duplicate initializers from the model to reduce its size.
Writes a new file in the same directory as onnx_file_path and returns the path to that file.
"""

model_file_folder = os.path.dirname(onnx_file_path)
model_file_name = os.path.basename(onnx_file_path)

Expand All @@ -76,7 +85,7 @@ def remove_dup_initializers(onnx_file_path):
for j in range(i + 1, len(inits)):
if j in dup_set:
continue
if is_equal_tensor_proto(inits[i], inits[j]):
if _is_equal_tensor_proto(inits[i], inits[j]):
dup_set.add(i)
dup_set.add(j)

Expand All @@ -103,8 +112,8 @@ def remove_dup_initializers(onnx_file_path):

print("total reduced size: ", total_reduced_size / 1024 / 1024 / 1024, "GB")

ind_to_replace = sorted(ind_to_replace, key=lambda x: x[0])
remove_dup_initializers_from_model(model, model, ind_to_replace)
ind_to_replace = sorted(ind_to_replace)
_remove_dup_initializers_from_model(model, model, ind_to_replace)

optimized_model_file_name = "optimized_" + model_file_name
new_model = os.path.join(model_file_folder, optimized_model_file_name)
Expand Down
1 change: 1 addition & 0 deletions examples/onnx/pytorch/summarization/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
torch >= 1.10
Loading

0 comments on commit b06abf1

Please sign in to comment.