diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx
index db4f8aa83d..1cbc1237f2 100644
--- a/docs/source/en/serialization.mdx
+++ b/docs/source/en/serialization.mdx
@@ -100,6 +100,7 @@ Ready-made configurations include the following architectures:
 - Table Transformer
 - Vision Encoder decoder
 - ViT
+- Whisper
 - XLM
 - XLM-RoBERTa
 - XLM-RoBERTa-XL
diff --git a/src/transformers/models/whisper/__init__.py b/src/transformers/models/whisper/__init__.py
index 71e354a936..2528e03a4d 100644
--- a/src/transformers/models/whisper/__init__.py
+++ b/src/transformers/models/whisper/__init__.py
@@ -21,7 +21,7 @@
 _import_structure = {
-    "configuration_whisper": ["WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP", "WhisperConfig"],
+    "configuration_whisper": ["WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP", "WhisperConfig", "WhisperOnnxConfig"],
     "feature_extraction_whisper": ["WhisperFeatureExtractor"],
     "processing_whisper": ["WhisperProcessor"],
     "tokenization_whisper": ["WhisperTokenizer"],
@@ -55,7 +55,7 @@
 ]

 if TYPE_CHECKING:
-    from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig
+    from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig, WhisperOnnxConfig
     from .feature_extraction_whisper import WhisperFeatureExtractor
     from .processing_whisper import WhisperProcessor
     from .tokenization_whisper import WhisperTokenizer
diff --git a/src/transformers/models/whisper/configuration_whisper.py b/src/transformers/models/whisper/configuration_whisper.py
index 6ee5ee9057..c25dab667d 100644
--- a/src/transformers/models/whisper/configuration_whisper.py
+++ b/src/transformers/models/whisper/configuration_whisper.py
@@ -14,10 +14,19 @@
 # limitations under the License.
 """ Whisper model configuration"""

+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
+
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
 from ...utils import logging

+if TYPE_CHECKING:
+    from ...feature_extraction_utils import FeatureExtractionMixin
+    from ...tokenization_utils_base import PreTrainedTokenizerBase
+    from ...utils import TensorType
+
 logger = logging.get_logger(__name__)

 WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
@@ -214,3 +223,59 @@ def __init__(
             begin_suppress_tokens=begin_suppress_tokens,
             **kwargs,
         )
+
+
+class WhisperOnnxConfig(OnnxSeq2SeqConfigWithPast):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = OrderedDict(
+            [
+                ("input_features", {0: "batch", 1: "feature_size", 2: "encoder_sequence"}),
+            ]
+        )
+        if self.use_past:
+            common_inputs["decoder_input_ids"] = {0: "batch"}
+        else:
+            common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+
+        if self.use_past:
+            self.fill_with_past_key_values_(common_inputs, direction="inputs")
+
+        return common_inputs
+
+    def generate_dummy_inputs(
+        self,
+        preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
+        batch_size: int = -1,
+        seq_length: int = -1,
+        is_pair: bool = False,
+        framework: Optional["TensorType"] = None,
+        sampling_rate: int = 22050,
+        time_duration: float = 5.0,
+        frequency: int = 220,
+    ) -> Mapping[str, Any]:
+        dummy_inputs = OrderedDict()
+        encoder_inputs = OnnxConfig.generate_dummy_inputs(
+            self,
+            preprocessor=preprocessor.feature_extractor,
+            batch_size=batch_size,
+            framework=framework,
+            sampling_rate=sampling_rate,
+            time_duration=time_duration,
+            frequency=frequency,
+        )
+        decoder_inputs = super().generate_dummy_inputs(
+            preprocessor.tokenizer, batch_size, seq_length, is_pair, framework
+        )
+
+        dummy_inputs["input_features"] = encoder_inputs.pop("input_features")
+        dummy_inputs["decoder_input_ids"] = decoder_inputs.pop("decoder_input_ids")
+
+        if "past_key_values" in decoder_inputs:
+            dummy_inputs["past_key_values"] = decoder_inputs.pop("past_key_values")
+
+        return dummy_inputs
+
+    @property
+    def atol_for_validation(self) -> float:
+        return 1e-3
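A small sketch of how the new `WhisperOnnxConfig` can be exercised (assumes this branch is installed; the checkpoint name is only an example):

```python
# Sanity-check the dynamic axes WhisperOnnxConfig declares.
from transformers import WhisperConfig
from transformers.models.whisper import WhisperOnnxConfig

config = WhisperConfig.from_pretrained("openai/whisper-tiny.en")
onnx_config = WhisperOnnxConfig(config, task="speech2seq-lm")

# input_features is (batch, feature_size, encoder_sequence); without past,
# decoder_input_ids keeps a full decoder_sequence axis.
print(onnx_config.inputs)

# With use_past=True, decoder_input_ids shrinks to a single step and
# past_key_values.* entries are appended to the inputs.
onnx_config_with_past = WhisperOnnxConfig(config, task="speech2seq-lm", use_past=True)
print(list(onnx_config_with_past.inputs))
```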
diff --git a/src/transformers/onnx/config.py b/src/transformers/onnx/config.py
index 5a1c3e6eed..1c8d10939a 100644
--- a/src/transformers/onnx/config.py
+++ b/src/transformers/onnx/config.py
@@ -104,6 +104,7 @@ class OnnxConfig(ABC):
         "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
         "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
         "vision2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
+        "speech2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
     }

     def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None):
@@ -262,6 +263,19 @@ def _generate_dummy_images(
             images.append(Image.fromarray(data.astype("uint8")).convert("RGB"))
         return images

+    def _generate_dummy_audio(
+        self, batch_size: int = 2, sampling_rate: int = 22050, time_duration: float = 5.0, frequency: int = 220
+    ):
+        audio_data = []
+        for _ in range(batch_size):
+            # time variable
+            t = np.linspace(0, time_duration, int(time_duration * sampling_rate), endpoint=False)
+
+            # generate pure sine wave at `frequency` Hz
+            audio_data.append(0.5 * np.sin(2 * np.pi * frequency * t))
+
+        return audio_data
+
     def generate_dummy_inputs(
         self,
         preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
@@ -273,6 +287,9 @@ def generate_dummy_inputs(
         num_channels: int = 3,
         image_width: int = 40,
         image_height: int = 40,
+        sampling_rate: int = 22050,
+        time_duration: float = 5.0,
+        frequency: int = 220,
         tokenizer: "PreTrainedTokenizerBase" = None,
     ) -> Mapping[str, Any]:
         """
@@ -297,6 +314,12 @@ def generate_dummy_inputs(
                 The width of the generated images.
             image_height (`int`, *optional*, defaults to 40):
                 The height of the generated images.
+            sampling_rate (`int`, *optional*, defaults to 22050):
+                The sampling rate for audio data generation.
+            time_duration (`float`, *optional*, defaults to 5.0):
+                Total seconds of sampling for audio data generation.
+            frequency (`int`, *optional*, defaults to 220):
+                The desired natural frequency of generated audio.

         Returns:
             Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
@@ -325,7 +348,12 @@ def generate_dummy_inputs(
                 seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
             )
             # Generate dummy inputs according to compute batch and sequence
-            dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size
+            input_token = (
+                preprocessor.unk_token
+                if (preprocessor.unk_token is not None and len(preprocessor.unk_token) > 0)
+                else "0"
+            )
+            dummy_input = [" ".join([input_token] * seq_length)] * batch_size
             if self.task == "multiple-choice":
                 # If dynamic axis (-1) we forward with a fixed dimension of 4 candidate answers to avoid optimizations
                 # made by ONNX
@@ -345,11 +373,32 @@ def generate_dummy_inputs(
             batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
             dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
             return dict(preprocessor(images=dummy_input, return_tensors=framework))
+        elif (
+            isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "input_features"
+        ):
+            # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
+            batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
+            dummy_input = self._generate_dummy_audio(batch_size, sampling_rate, time_duration, frequency)
+            return dict(preprocessor(dummy_input, return_tensors=framework))
         else:
             raise ValueError(
                 "Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
             )

+    def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
+        """
+        Generate inputs for ONNX Runtime using the reference model inputs. Override this to run inference with seq2seq
+        models which have the encoder and decoder exported as separate ONNX files.
+
+        Args:
+            reference_model_inputs (`Mapping[str, Tensor]`):
+                Reference inputs for the model.
+
+        Returns:
+            `Mapping[str, Tensor]`: The mapping holding the kwargs to provide to the model's forward function
+        """
+        return reference_model_inputs
+
     def patch_ops(self):
         for spec in self._patching_specs:
             custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op)
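The new audio branch boils down to synthesizing sine waves and handing them to the feature extractor. A standalone sketch of that path (the values mirror the new defaults; the shape comment assumes Whisper's standard mel settings):

```python
# Reproduce the dummy-audio path by hand: one 5 s, 220 Hz sine wave per batch
# item, fed through a feature extractor as generate_dummy_inputs does.
import numpy as np

from transformers import WhisperFeatureExtractor

batch_size, sampling_rate, time_duration, frequency = 2, 22050, 5.0, 220

t = np.linspace(0, time_duration, int(time_duration * sampling_rate), endpoint=False)
dummy_audio = [0.5 * np.sin(2 * np.pi * frequency * t) for _ in range(batch_size)]

feature_extractor = WhisperFeatureExtractor()
features = feature_extractor(dummy_audio, return_tensors="np")
print(features["input_features"].shape)  # (2, 80, 3000) with Whisper's default mel settings
```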
diff --git a/src/transformers/onnx/convert.py b/src/transformers/onnx/convert.py
index 234724699e..e953207b3a 100644
--- a/src/transformers/onnx/convert.py
+++ b/src/transformers/onnx/convert.py
@@ -145,7 +145,21 @@ def export_pytorch(
         device = torch.device(device)
         if device.type == "cuda" and torch.cuda.is_available():
             model.to(device)
-            model_inputs = dict((k, v.to(device)) for k, v in model_inputs.items())
+            model_inputs_device = dict()
+            for k, v in model_inputs.items():
+                if isinstance(v, Tuple):
+                    model_inputs_device[k] = tuple(
+                        x.to(device) if isinstance(x, torch.Tensor) else None for x in v
+                    )
+                elif isinstance(v, List):
+                    model_inputs_device[k] = [
+                        tuple(x.to(device) if isinstance(x, torch.Tensor) else None for x in t) for t in v
+                    ]
+                else:
+                    model_inputs_device[k] = v.to(device)
+
+            model_inputs = model_inputs_device
+
         inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
         onnx_outputs = list(config.outputs.keys())
@@ -404,9 +418,12 @@ def validate_model_outputs(
         else:
             ref_outputs_dict[name] = value

+    # Create onnxruntime inputs from the reference model inputs
+    reference_model_inputs_onnxruntime = config.generate_dummy_inputs_onnxruntime(reference_model_inputs)
+
     # We flatten potential collection of inputs (i.e. past_keys)
     onnx_inputs = {}
-    for name, value in reference_model_inputs.items():
+    for name, value in reference_model_inputs_onnxruntime.items():
         if isinstance(value, (list, tuple)):
             value = config.flatten_output_collection_property(name, value)
         onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
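The `generate_dummy_inputs_onnxruntime` hook added above is consumed here in `validate_model_outputs`. A hypothetical override illustrating why it exists (the class name and the input renames are invented for this sketch, not part of the PR): a decoder exported as a standalone ONNX file can remap the reference PyTorch inputs to the names its graph expects.

```python
# Hypothetical decoder-only config: remaps composite-model input names to the
# names a standalone decoder graph expects before ONNX Runtime validation.
from collections import OrderedDict
from typing import Any, Mapping

from transformers.onnx import OnnxSeq2SeqConfigWithPast


class DecoderOnnxConfig(OnnxSeq2SeqConfigWithPast):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "decoder_sequence"}),
                ("encoder_hidden_states", {0: "batch", 1: "encoder_sequence"}),
            ]
        )

    def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
        # The standalone decoder consumes decoder_input_ids as input_ids and
        # the first encoder output tensor as encoder_hidden_states.
        reference_model_inputs["input_ids"] = reference_model_inputs.pop("decoder_input_ids")
        reference_model_inputs["encoder_hidden_states"] = reference_model_inputs.pop("encoder_outputs")[0]
        return reference_model_inputs
```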
diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py
index 878fcce651..8e69c5a1a0 100644
--- a/src/transformers/onnx/features.py
+++ b/src/transformers/onnx/features.py
@@ -29,6 +29,7 @@
         AutoModelForSemanticSegmentation,
         AutoModelForSeq2SeqLM,
         AutoModelForSequenceClassification,
+        AutoModelForSpeechSeq2Seq,
         AutoModelForTokenClassification,
         AutoModelForVision2Seq,
     )
@@ -100,6 +101,7 @@ class FeaturesManager:
         "masked-im": AutoModelForMaskedImageModeling,
         "semantic-segmentation": AutoModelForSemanticSegmentation,
         "vision2seq-lm": AutoModelForVision2Seq,
+        "speech2seq-lm": AutoModelForSpeechSeq2Seq,
     }
     if is_tf_available():
         _TASKS_TO_TF_AUTOMODELS = {
@@ -492,6 +494,13 @@ class FeaturesManager:
         "vit": supported_features_mapping(
             "default", "image-classification", "masked-im", onnx_config_cls="models.vit.ViTOnnxConfig"
         ),
+        "whisper": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "speech2seq-lm",
+            "speech2seq-lm-with-past",
+            onnx_config_cls="models.whisper.WhisperOnnxConfig",
+        ),
         "xlm": supported_features_mapping(
             "default",
             "masked-lm",
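With the `FeaturesManager` entry in place, the new features become discoverable the usual way:

```python
from transformers.onnx.features import FeaturesManager

# Should list: default, default-with-past, speech2seq-lm, speech2seq-lm-with-past
print(list(FeaturesManager.get_supported_features_for_model_type("whisper")))
```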
diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py
index eac6ee0634..ab8610db71 100644
--- a/tests/onnx/test_onnx_v2.py
+++ b/tests/onnx/test_onnx_v2.py
@@ -218,6 +218,7 @@ def test_values_override(self):
     ("yolos", "hustvl/yolos-tiny"),
     ("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"),
     ("swin", "microsoft/swin-tiny-patch4-window7-224"),
+    ("whisper", "openai/whisper-tiny.en"),
 }

 PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
@@ -398,7 +399,7 @@ def _onnx_export_encoder_decoder_models(
         preprocessor = AutoTokenizer.from_pretrained(model_name)

         with NamedTemporaryFile("w") as decoder_output:
-            onnx_inputs, onnx_outputs = export(
+            _, onnx_outputs = export(
                 preprocessor,
                 decoder_model,
                 decoder_onnx_config,
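For completeness, an end-to-end export sketch using the programmatic API (the output path is illustrative; a `WhisperProcessor` is passed as the preprocessor because the config's `generate_dummy_inputs` reads both `.feature_extractor` and `.tokenizer` from it):

```python
from pathlib import Path

from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers.onnx import export, validate_model_outputs
from transformers.onnx.features import FeaturesManager

checkpoint = "openai/whisper-tiny.en"
model = WhisperForConditionalGeneration.from_pretrained(checkpoint)
processor = WhisperProcessor.from_pretrained(checkpoint)

# Resolve the ONNX config constructor for the new speech2seq-lm feature
model_kind, onnx_config_ctor = FeaturesManager.check_supported_model_or_raise(model, feature="speech2seq-lm")
onnx_config = onnx_config_ctor(model.config)

onnx_path = Path("whisper.onnx")
_, onnx_outputs = export(processor, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)

# atol_for_validation is the 1e-3 defined by WhisperOnnxConfig above
validate_model_outputs(onnx_config, processor, model, onnx_path, onnx_outputs, onnx_config.atol_for_validation)
```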