Added onnx config whisper (#19525)
* Added onnx config whisper

* added whisper support onnx

* add audio input data

* added whisper support onnx

* fixed the seqlength value

* Updated the whisper onnx config

* restore files to old version

* removed attention mask from inputs

* Updated get_dummy_input_onnxruntime docstring

* Updated relative imports and token generation

* update docstring
mht-sharma authored and sgugger committed Nov 1, 2022
1 parent 1ebb3f7 commit 0e654e0
Showing 7 changed files with 148 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/source/en/serialization.mdx
@@ -100,6 +100,7 @@ Ready-made configurations include the following architectures:
- Table Transformer
- Vision Encoder decoder
- ViT
- Whisper
- XLM
- XLM-RoBERTa
- XLM-RoBERTa-XL
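With Whisper in this list, a checkpoint can be exported like the other ready-made configurations, e.g. roughly `python -m transformers.onnx --model=openai/whisper-tiny.en --feature=speech2seq-lm whisper-onnx/` (a sketch, not part of the commit; the exact flags follow the CLI usage documented earlier in serialization.mdx, and `speech2seq-lm` is the feature added here).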
4 changes: 2 additions & 2 deletions src/transformers/models/whisper/__init__.py
@@ -21,7 +21,7 @@


_import_structure = {
    "configuration_whisper": ["WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP", "WhisperConfig"],
    "configuration_whisper": ["WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP", "WhisperConfig", "WhisperOnnxConfig"],
    "feature_extraction_whisper": ["WhisperFeatureExtractor"],
    "processing_whisper": ["WhisperProcessor"],
    "tokenization_whisper": ["WhisperTokenizer"],
@@ -55,7 +55,7 @@
    ]

if TYPE_CHECKING:
    from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig
    from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig, WhisperOnnxConfig
    from .feature_extraction_whisper import WhisperFeatureExtractor
    from .processing_whisper import WhisperProcessor
    from .tokenization_whisper import WhisperTokenizer
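A quick sanity check of the new re-export (a sketch, assuming this revision of transformers is installed locally):

# Sketch: WhisperOnnxConfig should now be importable from the whisper subpackage.
from transformers.models.whisper import WhisperConfig, WhisperOnnxConfig

onnx_config = WhisperOnnxConfig(WhisperConfig())
print(list(onnx_config.inputs.keys()))  # expected: ['input_features', 'decoder_input_ids']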
65 changes: 65 additions & 0 deletions src/transformers/models/whisper/configuration_whisper.py
@@ -14,10 +14,19 @@
# limitations under the License.
""" Whisper model configuration"""

from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union

from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
from ...utils import logging


if TYPE_CHECKING:
    from ...feature_extraction_utils import FeatureExtractionMixin
    from ...tokenization_utils_base import PreTrainedTokenizerBase
    from ...utils import TensorType

logger = logging.get_logger(__name__)

WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@@ -214,3 +223,59 @@ def __init__(
            begin_suppress_tokens=begin_suppress_tokens,
            **kwargs,
        )


class WhisperOnnxConfig(OnnxSeq2SeqConfigWithPast):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        common_inputs = OrderedDict(
            [
                ("input_features", {0: "batch", 1: "feature_size", 2: "encoder_sequence"}),
            ]
        )
        if self.use_past:
            common_inputs["decoder_input_ids"] = {0: "batch"}
        else:
            common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}

        if self.use_past:
            self.fill_with_past_key_values_(common_inputs, direction="inputs")

        return common_inputs

    def generate_dummy_inputs(
        self,
        preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional["TensorType"] = None,
        sampling_rate: int = 22050,
        time_duration: float = 5.0,
        frequency: int = 220,
    ) -> Mapping[str, Any]:
        dummy_inputs = OrderedDict()
        encoder_inputs = OnnxConfig.generate_dummy_inputs(
            self,
            preprocessor=preprocessor.feature_extractor,
            batch_size=batch_size,
            framework=framework,
            sampling_rate=sampling_rate,
            time_duration=time_duration,
            frequency=frequency,
        )
        decoder_inputs = super().generate_dummy_inputs(
            preprocessor.tokenizer, batch_size, seq_length, is_pair, framework
        )

        dummy_inputs["input_features"] = encoder_inputs.pop("input_features")
        dummy_inputs["decoder_input_ids"] = decoder_inputs.pop("decoder_input_ids")

        if "past_key_values" in decoder_inputs:
            dummy_inputs["past_key_values"] = decoder_inputs.pop("past_key_values")

        return dummy_inputs

    @property
    def atol_for_validation(self) -> float:
        return 1e-3
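A rough illustration (not part of the commit) of how this config is exercised when building export inputs; it assumes network access to `openai/whisper-tiny.en` and relies on `WhisperProcessor` exposing the `feature_extractor` and `tokenizer` attributes used above:

# Sketch: generating dummy export inputs with the new WhisperOnnxConfig.
from transformers import TensorType, WhisperConfig, WhisperProcessor
from transformers.models.whisper import WhisperOnnxConfig

config = WhisperConfig.from_pretrained("openai/whisper-tiny.en")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")

onnx_config = WhisperOnnxConfig(config, task="default")
dummy = onnx_config.generate_dummy_inputs(processor, framework=TensorType.PYTORCH)

# input_features comes from the feature extractor, decoder_input_ids from the tokenizer
print({name: tuple(tensor.shape) for name, tensor in dummy.items()})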
51 changes: 50 additions & 1 deletion src/transformers/onnx/config.py
@@ -104,6 +104,7 @@ class OnnxConfig(ABC):
        "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
        "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
        "vision2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
        "speech2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
    }

    def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None):
@@ -262,6 +263,19 @@ def _generate_dummy_images(
            images.append(Image.fromarray(data.astype("uint8")).convert("RGB"))
        return images

    def _generate_dummy_audio(
        self, batch_size: int = 2, sampling_rate: int = 22050, time_duration: float = 5.0, frequency: int = 220
    ):
        audio_data = []
        for _ in range(batch_size):
            # time variable
            t = np.linspace(0, time_duration, int(time_duration * sampling_rate), endpoint=False)

            # generate pure sine wave at `frequency` Hz
            audio_data.append(0.5 * np.sin(2 * np.pi * frequency * t))

        return audio_data

    def generate_dummy_inputs(
        self,
        preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
@@ -273,6 +287,9 @@ def generate_dummy_inputs(
        num_channels: int = 3,
        image_width: int = 40,
        image_height: int = 40,
        sampling_rate: int = 22050,
        time_duration: float = 5.0,
        frequency: int = 220,
        tokenizer: "PreTrainedTokenizerBase" = None,
    ) -> Mapping[str, Any]:
"""
Expand All @@ -297,6 +314,12 @@ def generate_dummy_inputs(
The width of the generated images.
image_height (`int`, *optional*, defaults to 40):
The height of the generated images.
sampling_rate (`int`, *optional* defaults to 22050)
The sampling rate for audio data generation.
time_duration (`float`, *optional* defaults to 5.0)
Total seconds of sampling for audio data generation.
frequency (`int`, *optional* defaults to 220)
The desired natural frequency of generated audio.
Returns:
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
@@ -325,7 +348,12 @@ def generate_dummy_inputs(
                seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
            )
            # Generate dummy inputs according to compute batch and sequence
            dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size
            input_token = (
                preprocessor.unk_token
                if (preprocessor.unk_token is not None and len(preprocessor.unk_token) > 0)
                else "0"
            )
            dummy_input = [" ".join([input_token]) * seq_length] * batch_size
            if self.task == "multiple-choice":
                # If dynamic axis (-1) we forward with a fixed dimension of 4 candidate answers to avoid optimizations
                # made by ONNX
@@ -345,11 +373,32 @@ def generate_dummy_inputs(
            batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
            dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
            return dict(preprocessor(images=dummy_input, return_tensors=framework))
        elif (
            isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "input_features"
        ):
            # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
            batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
            dummy_input = self._generate_dummy_audio(batch_size, sampling_rate, time_duration, frequency)
            return dict(preprocessor(dummy_input, return_tensors=framework))
        else:
            raise ValueError(
                "Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
            )

    def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
        """
        Generate inputs for ONNX Runtime using the reference model inputs. Override this to run inference with seq2seq
        models which have the encoder and decoder exported as separate ONNX files.

        Args:
            reference_model_inputs (`Mapping[str, Tensor]`):
                Reference inputs for the model.

        Returns:
            `Mapping[str, Tensor]`: The mapping holding the kwargs to provide to the model's forward function
        """
        return reference_model_inputs

    def patch_ops(self):
        for spec in self._patching_specs:
            custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op)
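For reference, the sine-wave helper added above is easy to reproduce in isolation; this small sketch (not part of the commit) shows the shapes produced by the default arguments:

# Sketch: what _generate_dummy_audio returns with batch_size=2 and the defaults above.
import numpy as np

batch_size, sampling_rate, time_duration, frequency = 2, 22050, 5.0, 220

t = np.linspace(0, time_duration, int(time_duration * sampling_rate), endpoint=False)
audio_data = [0.5 * np.sin(2 * np.pi * frequency * t) for _ in range(batch_size)]

print(len(audio_data), audio_data[0].shape)  # 2 waveforms of 110250 samples (5 s at 22.05 kHz)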
21 changes: 19 additions & 2 deletions src/transformers/onnx/convert.py
@@ -145,7 +145,21 @@ def export_pytorch(
        device = torch.device(device)
        if device.type == "cuda" and torch.cuda.is_available():
            model.to(device)
            model_inputs = dict((k, v.to(device)) for k, v in model_inputs.items())
            model_inputs_device = dict()
            for k, v in model_inputs.items():
                if isinstance(v, Tuple):
                    model_inputs_device[k] = tuple(
                        x.to(device) if isinstance(x, torch.Tensor) else None for x in v
                    )
                elif isinstance(v, List):
                    model_inputs_device[k] = [
                        tuple(x.to(device) if isinstance(x, torch.Tensor) else None for x in t) for t in v
                    ]
                else:
                    model_inputs_device[k] = v.to(device)

            model_inputs = model_inputs_device

        inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
        onnx_outputs = list(config.outputs.keys())

@@ -404,9 +418,12 @@ def validate_model_outputs(
        else:
            ref_outputs_dict[name] = value

    # Create onnxruntime inputs from the reference model inputs
    reference_model_inputs_onnxruntime = config.generate_dummy_inputs_onnxruntime(reference_model_inputs)

    # We flatten potential collection of inputs (i.e. past_keys)
    onnx_inputs = {}
    for name, value in reference_model_inputs.items():
    for name, value in reference_model_inputs_onnxruntime.items():
        if isinstance(value, (list, tuple)):
            value = config.flatten_output_collection_property(name, value)
            onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
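The `generate_dummy_inputs_onnxruntime` hook wired in above is what a split encoder/decoder export can override so validation feeds each ONNX file only the inputs it expects. A minimal sketch of such an override (the subclass and the key handling are hypothetical, not part of this commit):

# Sketch: a hypothetical decoder-side config that strips encoder-only inputs before onnxruntime validation.
from typing import Any, Mapping

from transformers.models.whisper import WhisperOnnxConfig


class WhisperDecoderOnnxConfig(WhisperOnnxConfig):  # hypothetical helper
    def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
        # The decoder graph does not consume raw audio features, so drop them before handing
        # the remaining tensors to onnxruntime.
        inputs = dict(reference_model_inputs)
        inputs.pop("input_features", None)
        return inputs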
9 changes: 9 additions & 0 deletions src/transformers/onnx/features.py
@@ -29,6 +29,7 @@
        AutoModelForSemanticSegmentation,
        AutoModelForSeq2SeqLM,
        AutoModelForSequenceClassification,
        AutoModelForSpeechSeq2Seq,
        AutoModelForTokenClassification,
        AutoModelForVision2Seq,
    )
@@ -100,6 +101,7 @@ class FeaturesManager:
            "masked-im": AutoModelForMaskedImageModeling,
            "semantic-segmentation": AutoModelForSemanticSegmentation,
            "vision2seq-lm": AutoModelForVision2Seq,
            "speech2seq-lm": AutoModelForSpeechSeq2Seq,
        }
    if is_tf_available():
        _TASKS_TO_TF_AUTOMODELS = {
@@ -492,6 +494,13 @@ class FeaturesManager:
        "vit": supported_features_mapping(
            "default", "image-classification", "masked-im", onnx_config_cls="models.vit.ViTOnnxConfig"
        ),
        "whisper": supported_features_mapping(
            "default",
            "default-with-past",
            "speech2seq-lm",
            "speech2seq-lm-with-past",
            onnx_config_cls="models.whisper.WhisperOnnxConfig",
        ),
        "xlm": supported_features_mapping(
            "default",
            "masked-lm",
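With the mapping entry above, the features module should report Whisper as exportable; a quick check (a sketch, assuming `get_supported_features_for_model_type` keeps its current signature):

# Sketch: querying the newly registered Whisper export features.
from transformers.onnx.features import FeaturesManager

features = FeaturesManager.get_supported_features_for_model_type("whisper")
print(sorted(features))  # expected to include "speech2seq-lm" and "speech2seq-lm-with-past"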
3 changes: 2 additions & 1 deletion tests/onnx/test_onnx_v2.py
@@ -218,6 +218,7 @@ def test_values_override(self):
    ("yolos", "hustvl/yolos-tiny"),
    ("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"),
    ("swin", "microsoft/swin-tiny-patch4-window7-224"),
    ("whisper", "openai/whisper-tiny.en"),
}

PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
@@ -398,7 +399,7 @@ def _onnx_export_encoder_decoder_models(
        preprocessor = AutoTokenizer.from_pretrained(model_name)

        with NamedTemporaryFile("w") as decoder_output:
            onnx_inputs, onnx_outputs = export(
            _, onnx_outputs = export(
                preprocessor,
                decoder_model,
                decoder_onnx_config,
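To exercise the new entry locally, the export tests can be filtered to Whisper, e.g. `RUN_SLOW=1 pytest tests/onnx/test_onnx_v2.py -k "whisper"` (a sketch; it assumes the ONNX export tests are gated behind the usual slow-test flag and that the optional onnx/onnxruntime dependencies are installed).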
