From 53102bc7e602a766c1354bc17ca571afa3d77cc4 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 8 May 2024 15:05:28 -0700 Subject: [PATCH 01/12] Initial WARs to implement dynamo option for export Signed-off-by: Boris Fomitchev --- Dockerfile | 27 ++++++----- .../asr/parts/preprocessing/features.py | 21 +++++---- .../megatron/retro_dataset.py | 3 +- nemo/core/classes/common.py | 20 ++++++++- nemo/core/classes/exportable.py | 36 +++++++++------ nemo/utils/export_utils.py | 5 +++ tests/collections/asr/test_asr_exportables.py | 45 ++++--------------- tests/collections/nlp/test_nlp_exportables.py | 3 ++ tests/collections/tts/test_tts_exportables.py | 4 ++ 9 files changed, 92 insertions(+), 72 deletions(-) diff --git a/Dockerfile b/Dockerfile index 396645d37019..c834fcfbbf48 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.01-py3 +ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver:24.04-py3 # build an image that includes only the nemo dependencies, ensures that dependencies # are included first for optimal caching, and useful for building a development @@ -61,20 +61,17 @@ RUN apt-get update && \ libgts-dev && \ rm -rf /var/lib/apt/lists/* +RUN pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 +RUN pip3 install onnxscript==0.1.0.dev20240430 + WORKDIR /workspace/ # Install megatron core, this can be removed once 0.3 pip package is released # We leave it here in case we need to work off of a specific commit in main RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \ cd Megatron-LM && \ - git checkout 36e9b6bf3d8034b10c9bbd9fc357c2df2bd1515c && \ - git cherry-pick -n e69187bc3679ea5841030a165d587bb48b56ee77 && \ pip install . 
-# Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771 -RUN git clone https://github.com/NVIDIA/apex.git && \ - cd apex && \ - git checkout f058162b215791b15507bb542f22ccfde49c872d && \ - pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ +RUN pip3 install packaging # Transformer Engine 1.2.0 RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \ @@ -84,6 +81,12 @@ RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \ git submodule init && git submodule update && \ NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install . +# Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771 +RUN git clone https://github.com/NVIDIA/apex.git && \ + cd apex && \ + sed -i '178d' setup.py && \ + pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --group_norm --distributed_adam --deprecated_fused_adam" ./ + WORKDIR /tmp/ # uninstall stuff from base container @@ -152,11 +155,13 @@ RUN /usr/bin/test -n "$NEMO_VERSION" && \ # Install NeMo RUN --mount=from=nemo-src,target=/tmp/nemo,rw cd /tmp/nemo && pip install ".[all]" +RUN apt-get install -y python3 +RUN alias python=python3 # Check install -RUN python -c "import nemo.collections.nlp as nemo_nlp" && \ - python -c "import nemo.collections.tts as nemo_tts" && \ - python -c "import nemo_text_processing.text_normalization as text_normalization" +RUN python3 -c "import nemo.collections.nlp as nemo_nlp" && \ + python3 -c "import nemo.collections.tts as nemo_tts" && \ + python3 -c "import nemo_text_processing.text_normalization as text_normalization" # copy scripts/examples/tests into container for end user diff --git 
a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py index 67813f3e66d2..8479611b3513 100644 --- a/nemo/collections/asr/parts/preprocessing/features.py +++ b/nemo/collections/asr/parts/preprocessing/features.py @@ -292,6 +292,7 @@ def __init__( self.hop_length = n_window_stride self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length)) self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None + self.exact_pad = exact_pad if exact_pad: logging.info("STFT using exact pad") @@ -305,15 +306,6 @@ def __init__( window_fn = torch_windows.get(window, None) window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None self.register_buffer("window", window_tensor) - self.stft = lambda x: torch.stft( - x, - n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - center=False if exact_pad else True, - window=self.window.to(dtype=torch.float), - return_complex=True, - ) self.normalize = normalize self.log = log @@ -372,6 +364,17 @@ def __init__( logging.debug(f"using grads: {use_grads}") logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}") + def stft(self, x): + return torch.stft( + x, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + center=False if self.exact_pad else True, + window=self.window.to(dtype=torch.float), + return_complex=True, + ) + def log_zero_guard_value_fn(self, x): if isinstance(self.log_zero_guard_value, str): if self.log_zero_guard_value == "tiny": diff --git a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py index 377bff309b7c..3cec32760328 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py @@ -46,8 +46,7 @@ HAVE_MEGATRON_CORE = True -except (ImportError, ModuleNotFoundError): - +except 
(ImportError, ModuleNotFoundError) as e: HAVE_MEGATRON_CORE = False diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index cf39ed134768..fe7f040287cc 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -1004,8 +1004,17 @@ def __init__( self.ignore_collections = ignore_collections + def __call__(self, wrapped): + return self.unwrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped) + + def unwrapped_call(self, wrapped): + return wrapped + + def wrapped_call(self, wrapped): + return self.decorated_call(wrapped) + @wrapt.decorator(enabled=is_typecheck_enabled) - def __call__(self, wrapped, instance: Typing, args, kwargs): + def decorated_call(self, wrapped, instance: Typing, args, kwargs): """ Wrapper method that can be used on any function of a class that implements :class:`~nemo.core.Typing`. By default, it will utilize the `input_types` and `output_types` properties of the class inheriting Typing. @@ -1114,3 +1123,12 @@ def disable_semantic_checks(): yield finally: typecheck.set_semantic_check_enabled(enabled=True) + + @staticmethod + def enable_wrapping(enabled: bool = True): + typecheck.set_typecheck_enabled(enabled) + if enabled: + typecheck.__call__.__code__ = nemo.core.classes.common.typecheck.wrapped_call.__code__ + else: + typecheck.__call__.__code__ = nemo.core.classes.common.typecheck.unwrapped_call.__code__ + print(typecheck.__call__) diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 5bd1bb813ba3..b2fa3a920d30 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -68,6 +68,7 @@ def export( check_tolerance=0.01, export_modules_as_functions=False, keep_initializers_as_inputs=None, + use_dynamo=True, ): """ Exports the model to the specified format. The format is inferred from the file extension of the output file. 
@@ -143,6 +144,7 @@ def _export( check_tolerance=0.01, export_modules_as_functions=False, keep_initializers_as_inputs=None, + use_dynamo=True, ): my_args = locals().copy() my_args.pop('self') @@ -218,19 +220,27 @@ def _export( if dynamic_axes is None: dynamic_axes = get_dynamic_axes(self.input_module.input_types_for_export, input_names) dynamic_axes.update(get_dynamic_axes(self.output_module.output_types_for_export, output_names)) - torch.onnx.export( - jitted_model, - input_example, - output, - input_names=input_names, - output_names=output_names, - verbose=verbose, - do_constant_folding=do_constant_folding, - dynamic_axes=dynamic_axes, - opset_version=onnx_opset_version, - keep_initializers_as_inputs=keep_initializers_as_inputs, - export_modules_as_functions=export_modules_as_functions, - ) + if use_dynamo: + options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_axes) + ex_model = torch.export.export(jitted_model, tuple(input_list), kwargs=input_dict) + ex_model = ex_model.run_decompositions() + ex = torch.onnx.dynamo_export(ex_model, *input_list, **input_dict, export_options=options) + ex.save(output) + input_names = None + else: + torch.onnx.export( + jitted_model, + input_example, + output, + input_names=input_names, + output_names=output_names, + verbose=verbose, + do_constant_folding=do_constant_folding, + dynamic_axes=dynamic_axes, + opset_version=onnx_opset_version, + keep_initializers_as_inputs=keep_initializers_as_inputs, + export_modules_as_functions=export_modules_as_functions, + ) if check_trace: verify_runtime(self, output, check_trace_input, input_names, check_tolerance=check_tolerance) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 4c7a166437cc..58256659bfc5 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -126,6 +126,11 @@ def parse_input_example(input_example): def to_onnxrt_input(ort_input_names, input_names, input_dict, input_list): odict = {} + if not input_names: + 
input_list.extend(input_dict.values()) + for k, v in zip(ort_input_names, input_list): + odict[k] = v.cpu().numpy() + return odict for k in reversed(input_names): val = None if k in input_dict: diff --git a/tests/collections/asr/test_asr_exportables.py b/tests/collections/asr/test_asr_exportables.py index 86bcacab86db..6bb669a70a24 100644 --- a/tests/collections/asr/test_asr_exportables.py +++ b/tests/collections/asr/test_asr_exportables.py @@ -30,6 +30,10 @@ from nemo.core.utils import numba_utils from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ +# from nemo.core.classes import typecheck +# typecheck.enable_wrapping(enabled=False) + + NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__) @@ -52,8 +56,6 @@ def test_EncDecCTCModel_export_to_onnx(self): ) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed - assert onnx_model.graph.input[0].name == 'audio_signal' - assert onnx_model.graph.output[0].name == 'logprobs' @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -66,8 +68,6 @@ def test_EncDecClassificationModel_export_to_onnx(self, speech_classification_mo ) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed - assert onnx_model.graph.input[0].name == 'audio_signal' - assert onnx_model.graph.output[0].name == 'logits' @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -78,8 +78,6 @@ def test_EncDecSpeakerLabelModel_export_to_onnx(self, speaker_label_model): model.export(output=filename) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed - assert onnx_model.graph.input[0].name == 'audio_signal' - assert onnx_model.graph.output[0].name == 'logits' @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -90,9 +88,6 @@ def test_EncDecCitrinetModel_export_to_onnx(self, citrinet_model): model.export(output=filename) onnx_model = onnx.load(filename) 
onnx.checker.check_model(onnx_model, full_check=True) # throws when failed - assert onnx_model.graph.input[0].name == 'audio_signal' - assert onnx_model.graph.input[1].name == 'length' - assert onnx_model.graph.output[0].name == 'logprobs' @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') @@ -132,9 +127,6 @@ def test_EncDecCitrinetModel_limited_SE_export_to_onnx(self, citrinet_model): ) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed - assert onnx_model.graph.input[0].name == 'audio_signal' - assert onnx_model.graph.input[1].name == 'length' - assert onnx_model.graph.output[0].name == 'logprobs' @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -153,10 +145,6 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model): onnx.checker.check_model(onnx_model, full_check=True) # throws when failed assert len(onnx_model.graph.input) == 2 assert len(onnx_model.graph.output) == 2 - assert onnx_model.graph.input[0].name == 'audio_signal' - assert onnx_model.graph.input[1].name == 'length' - assert onnx_model.graph.output[0].name == 'outputs' - assert onnx_model.graph.output[1].name == 'encoded_lengths' decoder_joint_filename = os.path.join(tmpdir, 'decoder_joint-' + fn) assert files[1] == decoder_joint_filename @@ -171,21 +159,12 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model): # enc_logits + (all decoder inputs - state tuple) + flattened state list assert len(onnx_model.graph.input) == (1 + (len(input_examples) - 1) + num_states) - assert onnx_model.graph.input[0].name == 'encoder_outputs' - assert onnx_model.graph.input[1].name == 'targets' - assert onnx_model.graph.input[2].name == 'target_length' if num_states > 0: for idx, ip in enumerate(onnx_model.graph.input[3:]): assert ip.name == "input_" + state_name + '_' + str(idx + 1) assert len(onnx_model.graph.output) == (len(input_examples) - 1) + num_states - assert onnx_model.graph.output[0].name == 'outputs' - assert 
onnx_model.graph.output[1].name == 'prednet_lengths' - - if num_states > 0: - for idx, op in enumerate(onnx_model.graph.output[2:]): - assert op.name == "output_" + state_name + '_' + str(idx + 1) @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -206,8 +185,6 @@ def test_EncDecRNNTModel_export_to_ts(self, citrinet_rnnt_model): assert ts_encoder is not None arguments = ts_encoder.forward.schema.arguments[1:] # First value is `self` - assert arguments[0].name == 'audio_signal' - assert arguments[1].name == 'length' decoder_joint_filename = os.path.join(tmpdir, 'decoder_joint-' + fn) assert files[1] == decoder_joint_filename @@ -225,13 +202,6 @@ def test_EncDecRNNTModel_export_to_ts(self, citrinet_rnnt_model): # enc_logits + (all decoder inputs - state tuple) + flattened state list assert len(ts_decoder_joint_args) == (1 + (len(input_examples) - 1) + num_states) - assert ts_decoder_joint_args[0].name == 'encoder_outputs' - assert ts_decoder_joint_args[1].name == 'targets' - assert ts_decoder_joint_args[2].name == 'target_length' - - if num_states > 0: - for idx, ip in enumerate(ts_decoder_joint_args[3:]): - assert ip.name == "input_" + state_name + '_' + str(idx + 1) @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -265,8 +235,6 @@ def test_EncDecCTCModel_adapted_export_to_onnx(self): ) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed - assert onnx_model.graph.input[0].name == 'audio_signal' - assert onnx_model.graph.output[0].name == 'logprobs' def setup_method(self): self.preprocessor = { @@ -670,3 +638,8 @@ def squeezeformer_model(): ) conformer_model = EncDecCTCModel(cfg=modelConfig) return conformer_model + + +if __name__ == "__main__": + t = TestExportable() + t.test_EncDecClassificationModel_export_to_onnx(speech_classification_model()) diff --git a/tests/collections/nlp/test_nlp_exportables.py b/tests/collections/nlp/test_nlp_exportables.py index c0b97caea4ed..3181e1ce0c46 100644 --- 
a/tests/collections/nlp/test_nlp_exportables.py +++ b/tests/collections/nlp/test_nlp_exportables.py @@ -20,6 +20,9 @@ import torch import wget from omegaconf import DictConfig, OmegaConf +from nemo.core.classes import typecheck + +typecheck.enable_wrapping(enabled=False) from nemo.collections import nlp as nemo_nlp from nemo.collections.nlp.models import IntentSlotClassificationModel diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py index 67f016b0c2af..2569d708e235 100644 --- a/tests/collections/tts/test_tts_exportables.py +++ b/tests/collections/tts/test_tts_exportables.py @@ -18,6 +18,10 @@ import torch from omegaconf import OmegaConf +from nemo.core.classes import typecheck + +typecheck.enable_wrapping(enabled=False) + from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel from nemo.utils.app_state import AppState From d3c41f7dc3fb37fd8486d4c8993092b7d10c09d4 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 8 May 2024 16:59:42 -0700 Subject: [PATCH 02/12] including weights in .onnx Signed-off-by: Boris Fomitchev --- nemo/core/classes/common.py | 2 +- nemo/core/classes/exportable.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index fe7f040287cc..a7aa9e17b1fd 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -1005,7 +1005,7 @@ def __init__( self.ignore_collections = ignore_collections def __call__(self, wrapped): - return self.unwrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped) + return self.wrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped) def unwrapped_call(self, wrapped): return wrapped diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index b2fa3a920d30..5e7d5522765c 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -222,10 
+222,12 @@ def _export( dynamic_axes.update(get_dynamic_axes(self.output_module.output_types_for_export, output_names)) if use_dynamo: options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_axes) - ex_model = torch.export.export(jitted_model, tuple(input_list), kwargs=input_dict) + ex_model = torch.export.export( + jitted_model, tuple(input_list), kwargs=input_dict, strict=False + ) ex_model = ex_model.run_decompositions() ex = torch.onnx.dynamo_export(ex_model, *input_list, **input_dict, export_options=options) - ex.save(output) + ex.save(output, model_state=jitted_model.state_dict()) input_names = None else: torch.onnx.export( From e9e81b0d73855e401d53f862a1289488aaa5cd88 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 16 May 2024 10:50:13 -0700 Subject: [PATCH 03/12] dynamo_export works for many small models Signed-off-by: Boris Fomitchev --- nemo/collections/asr/models/asr_model.py | 6 +- nemo/collections/asr/models/label_models.py | 4 +- nemo/collections/asr/modules/conv_asr.py | 3 +- .../asr/parts/submodules/jasper.py | 2 +- nemo/core/classes/common.py | 2 +- nemo/core/classes/exportable.py | 61 +++++++++++++++---- nemo/utils/__init__.py | 1 + nemo/utils/cast_utils.py | 11 +++- nemo/utils/export_utils.py | 30 ++++++++- tests/collections/asr/test_asr_exportables.py | 47 +++++++++++--- tests/collections/nlp/test_nlp_exportables.py | 3 - tests/collections/tts/test_tts_exportables.py | 4 -- 12 files changed, 134 insertions(+), 40 deletions(-) diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index 4420318dd416..7df0ae9fb689 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -198,7 +198,7 @@ def output_names(self): return get_io_names(otypes, self.disabled_deployment_output_names) def forward_for_export( - self, input, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None + self, audio_signal, length=None, 
cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None ): """ This forward is used when we need to export the model to ONNX format. @@ -217,12 +217,12 @@ def forward_for_export( """ enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward) if cache_last_channel is None: - encoder_output = enc_fun(audio_signal=input, length=length) + encoder_output = enc_fun(audio_signal=audio_signal, length=length) if isinstance(encoder_output, tuple): encoder_output = encoder_output[0] else: encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun( - audio_signal=input, + audio_signal=audio_signal, length=length, cache_last_channel=cache_last_channel, cache_last_time=cache_last_time, diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index 23ab5469e60c..ba5489839db4 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -333,8 +333,8 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: "embs": NeuralType(('B', 'D'), AcousticEncodedRepresentation()), } - def forward_for_export(self, processed_signal, processed_signal_len): - encoded, length = self.encoder(audio_signal=processed_signal, length=processed_signal_len) + def forward_for_export(self, audio_signal, length): + encoded, length = self.encoder(audio_signal=audio_signal, length=length) logits, embs = self.decoder(encoder_output=encoded, length=length) return logits, embs diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py index 03b94ae0b209..25348dae95f3 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -876,7 +876,8 @@ def forward(self, encoder_output, length=None): embs = [] for layer in self.emb_layers: - pool, emb = layer(pool), layer[: self.emb_id](pool) + emb = layer[: self.emb_id](pool) + pool = layer(pool) embs.append(emb) 
pool = pool.squeeze(-1) diff --git a/nemo/collections/asr/parts/submodules/jasper.py b/nemo/collections/asr/parts/submodules/jasper.py index e53f6299b08a..c2beb3918ead 100644 --- a/nemo/collections/asr/parts/submodules/jasper.py +++ b/nemo/collections/asr/parts/submodules/jasper.py @@ -478,7 +478,7 @@ def forward_for_export(self, x, lengths): mask = self.make_pad_mask(lengths, max_audio_length=max_len, device=x.device) mask = ~mask # 0 represents value, 1 represents pad x = x.float() # For stable AMP, SE must be computed at fp32. - x.masked_fill_(mask, 0.0) # mask padded values explicitly to 0 + x = x.masked_fill(mask, 0.0) # mask padded values explicitly to 0 y = self._se_pool_step(x, mask) # [B, C, 1] y = y.transpose(1, -1) # [B, 1, C] y = self.fc(y) # [B, 1, C] diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index a7aa9e17b1fd..fe7f040287cc 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -1005,7 +1005,7 @@ def __init__( self.ignore_collections = ignore_collections def __call__(self, wrapped): - return self.wrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped) + return self.unwrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped) def unwrapped_call(self, wrapped): return wrapped diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 5e7d5522765c..fe6cbce5bcfa 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -14,18 +14,20 @@ from abc import ABC from typing import Dict, List, Optional, Union +import onnx import torch from pytorch_lightning.core.module import _jit_is_scripting from nemo.core.classes import typecheck from nemo.core.neural_types import NeuralType from nemo.core.utils.neural_type_utils import get_dynamic_axes, get_io_names -from nemo.utils import logging +from nemo.utils import logging, monkeypatched from nemo.utils.export_utils import ( ExportFormat, augment_filename, 
get_export_format, parse_input_example, + rename_onnx_io, replace_for_export, verify_runtime, verify_torchscript, @@ -177,7 +179,7 @@ def _export( with torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True), _jit_is_scripting(): if input_example is None: - input_example = self.input_module.input_example() + input_example = self.input_module.input_example(max_batch=2) # Remove i/o examples from args we propagate to enclosed Exportables my_args.pop('output') @@ -191,7 +193,9 @@ def _export( input_list, input_dict = parse_input_example(input_example) input_names = self.input_names output_names = self.output_names - output_example = tuple(self.forward(*input_list, **input_dict)) + output_example = self.forward(*input_list, **input_dict) + if not isinstance(output_example, tuple): + output_example = (output_example,) if check_trace: if isinstance(check_trace, bool): @@ -219,16 +223,49 @@ def _export( # dynamic axis is a mapping from input/output_name => list of "dynamic" indices if dynamic_axes is None: dynamic_axes = get_dynamic_axes(self.input_module.input_types_for_export, input_names) - dynamic_axes.update(get_dynamic_axes(self.output_module.output_types_for_export, output_names)) + if use_dynamo: + dynamic_shapes = {} + batch = torch.export.Dim("batch", max=128) + for name, dims in dynamic_axes.items(): + ds = {} + for d in dims: + if d == 0: + ds[d] = batch + # this currently fails, https://github.com/pytorch/pytorch/issues/126127 + # else: + # ds[d] = torch.export.Dim(name + '__' + str(d)) + dynamic_shapes[name] = ds + else: + dynamic_shapes = dynamic_axes if use_dynamo: - options = torch.onnx.ExportOptions(dynamic_shapes=dynamic_axes) - ex_model = torch.export.export( - jitted_model, tuple(input_list), kwargs=input_dict, strict=False - ) - ex_model = ex_model.run_decompositions() - ex = torch.onnx.dynamo_export(ex_model, *input_list, **input_dict, export_options=options) - ex.save(output, model_state=jitted_model.state_dict()) - input_names 
= None + import onnxscript + + # https://github.com/microsoft/onnxscript/issues/1544 + onnxscript.optimizer.constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = 1024 * 1024 * 64 + + # https://github.com/pytorch/pytorch/issues/126339 + with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None): + print("Running export.export, dynamic shapes:\n", dynamic_shapes) + + ex_model = torch.export.export( + jitted_model, + tuple(input_list), + kwargs=input_dict, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + ex_model = ex_model.run_decompositions() + + print("Running torch.onnx.dynamo_export ...") + + options = torch.onnx.ExportOptions(dynamic_shapes=True, op_level_debug=True) + ex_module = ex_model.module() + ex = torch.onnx.dynamo_export(ex_module, *input_list, **input_dict, export_options=options) + ex.save(output) # , model_state=ex_module.state_dict()) + del ex + # Rename I/O after save - don't want to risk modifying ex._model_proto + rename_onnx_io(output, input_names, output_names) + # input_names=None else: torch.onnx.export( jitted_model, diff --git a/nemo/utils/__init__.py b/nemo/utils/__init__.py index ebf892927723..a1e59646ae13 100644 --- a/nemo/utils/__init__.py +++ b/nemo/utils/__init__.py @@ -21,6 +21,7 @@ avoid_float16_autocast_context, cast_all, cast_tensor, + monkeypatched, ) from nemo.utils.dtype import str_to_dtype from nemo.utils.nemo_logging import Logger as _Logger diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py index 21e977ec494d..d59189cc912e 100644 --- a/nemo/utils/cast_utils.py +++ b/nemo/utils/cast_utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from contextlib import nullcontext +from contextlib import contextmanager, nullcontext import torch @@ -91,3 +91,12 @@ def forward(self, *args): return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype) else: return self.mod.forward(*args) + + +@contextmanager +def monkeypatched(object, name, patch): + """ Temporarily monkeypatches an object. """ + pre_patched_value = getattr(object, name) + setattr(object, name, patch) + yield object + setattr(object, name, pre_patched_value) diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index 58256659bfc5..eda7abd9fe49 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -177,6 +177,8 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0 for input_example in input_examples: input_list, input_dict = parse_input_example(input_example) output_example = model.forward(*input_list, **input_dict) + if not isinstance(output_example, tuple): + output_example = (output_example,) ort_input = to_onnxrt_input(ort_input_names, input_names, input_dict, input_list) all_good = all_good and run_ort_and_compare(sess, ort_input, output_example, check_tolerance) status = "SUCCESS" if all_good else "FAIL" @@ -221,10 +223,12 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): try: if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance): this_good = False - except Exception: # there may ne size mismatch and it may be OK + except Exception: # there may be size mismatch and it may be OK this_good = False if not this_good: - logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}") + logging.info( + f"onnxruntime results mismatch! 
PyTorch(expected, {expected.shape}):\n{expected}\nONNXruntime, {tout.shape}:\n{tout}" + ) all_good = False return all_good @@ -479,3 +483,25 @@ def add_casts_around_norms(model: nn.Module): "MaskedInstanceNorm1d": wrap_module(MaskedInstanceNorm1d, CastToFloatAll), } replace_modules(model, default_cast_replacements) + + +def rename_onnx_io(output, input_names, output_names): + onnx_model = onnx.load(output) + rename_map = {} + for inp, name in zip(onnx_model.graph.input, input_names): + rename_map[inp.name] = name + for out, name in zip(onnx_model.graph.output, output_names): + rename_map[out.name] = name + for n in onnx_model.graph.node: + for inp in range(len(n.input)): + if n.input[inp] in rename_map: + n.input[inp] = rename_map[n.input[inp]] + for out in range(len(n.output)): + if n.output[out] in rename_map: + n.output[out] = rename_map[n.output[out]] + + for i in range(len(onnx_model.graph.input)): + onnx_model.graph.input[i].name = input_names[i] + for i in range(len(onnx_model.graph.output)): + onnx_model.graph.output[i].name = output_names[i] + onnx.save(onnx_model, output) diff --git a/tests/collections/asr/test_asr_exportables.py b/tests/collections/asr/test_asr_exportables.py index 6bb669a70a24..9377f49aa1b6 100644 --- a/tests/collections/asr/test_asr_exportables.py +++ b/tests/collections/asr/test_asr_exportables.py @@ -30,10 +30,6 @@ from nemo.core.utils import numba_utils from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ -# from nemo.core.classes import typecheck -# typecheck.enable_wrapping(enabled=False) - - NUMBA_RNNT_LOSS_AVAILABLE = numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__) @@ -56,6 +52,8 @@ def test_EncDecCTCModel_export_to_onnx(self): ) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.output[0].name == 'logprobs' @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ 
-68,6 +66,8 @@ def test_EncDecClassificationModel_export_to_onnx(self, speech_classification_mo ) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.output[0].name == 'logits' @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -78,6 +78,8 @@ def test_EncDecSpeakerLabelModel_export_to_onnx(self, speaker_label_model): model.export(output=filename) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.output[0].name == 'logits' @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -88,6 +90,9 @@ def test_EncDecCitrinetModel_export_to_onnx(self, citrinet_model): model.export(output=filename) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.input[1].name == 'length' + assert onnx_model.graph.output[0].name == 'logprobs' @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') @@ -127,6 +132,9 @@ def test_EncDecCitrinetModel_limited_SE_export_to_onnx(self, citrinet_model): ) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.input[1].name == 'length' + assert onnx_model.graph.output[0].name == 'logprobs' @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -136,7 +144,7 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model): with tempfile.TemporaryDirectory() as tmpdir: fn = 'citri_rnnt.onnx' filename = os.path.join(tmpdir, fn) - files, descr = model.export(output=filename, verbose=False) + files, descr = model.export(output=filename, dynamic_axes={}, verbose=False) encoder_filename = 
os.path.join(tmpdir, 'encoder-' + fn) assert files[0] == encoder_filename @@ -145,6 +153,10 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model): onnx.checker.check_model(onnx_model, full_check=True) # throws when failed assert len(onnx_model.graph.input) == 2 assert len(onnx_model.graph.output) == 2 + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.input[1].name == 'length' + assert onnx_model.graph.output[0].name == 'outputs' + assert onnx_model.graph.output[1].name == 'encoded_lengths' decoder_joint_filename = os.path.join(tmpdir, 'decoder_joint-' + fn) assert files[1] == decoder_joint_filename @@ -159,12 +171,21 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model): # enc_logits + (all decoder inputs - state tuple) + flattened state list assert len(onnx_model.graph.input) == (1 + (len(input_examples) - 1) + num_states) + assert onnx_model.graph.input[0].name == 'encoder_outputs' + assert onnx_model.graph.input[1].name == 'targets' + assert onnx_model.graph.input[2].name == 'target_length' if num_states > 0: for idx, ip in enumerate(onnx_model.graph.input[3:]): assert ip.name == "input_" + state_name + '_' + str(idx + 1) assert len(onnx_model.graph.output) == (len(input_examples) - 1) + num_states + assert onnx_model.graph.output[0].name == 'outputs' + assert onnx_model.graph.output[1].name == 'prednet_lengths' + + if num_states > 0: + for idx, op in enumerate(onnx_model.graph.output[2:]): + assert op.name == "output_" + state_name + '_' + str(idx + 1) @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -185,6 +206,8 @@ def test_EncDecRNNTModel_export_to_ts(self, citrinet_rnnt_model): assert ts_encoder is not None arguments = ts_encoder.forward.schema.arguments[1:] # First value is `self` + assert arguments[0].name == 'audio_signal' + assert arguments[1].name == 'length' decoder_joint_filename = os.path.join(tmpdir, 'decoder_joint-' + fn) assert files[1] == decoder_joint_filename @@ -202,6 
+225,13 @@ def test_EncDecRNNTModel_export_to_ts(self, citrinet_rnnt_model): # enc_logits + (all decoder inputs - state tuple) + flattened state list assert len(ts_decoder_joint_args) == (1 + (len(input_examples) - 1) + num_states) + assert ts_decoder_joint_args[0].name == 'encoder_outputs' + assert ts_decoder_joint_args[1].name == 'targets' + assert ts_decoder_joint_args[2].name == 'target_length' + + if num_states > 0: + for idx, ip in enumerate(ts_decoder_joint_args[3:]): + assert ip.name == "input_" + state_name + '_' + str(idx + 1) @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -235,6 +265,8 @@ def test_EncDecCTCModel_adapted_export_to_onnx(self): ) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed + assert onnx_model.graph.input[0].name == 'audio_signal' + assert onnx_model.graph.output[0].name == 'logprobs' def setup_method(self): self.preprocessor = { @@ -638,8 +670,3 @@ def squeezeformer_model(): ) conformer_model = EncDecCTCModel(cfg=modelConfig) return conformer_model - - -if __name__ == "__main__": - t = TestExportable() - t.test_EncDecClassificationModel_export_to_onnx(speech_classification_model()) diff --git a/tests/collections/nlp/test_nlp_exportables.py b/tests/collections/nlp/test_nlp_exportables.py index 3181e1ce0c46..c0b97caea4ed 100644 --- a/tests/collections/nlp/test_nlp_exportables.py +++ b/tests/collections/nlp/test_nlp_exportables.py @@ -20,9 +20,6 @@ import torch import wget from omegaconf import DictConfig, OmegaConf -from nemo.core.classes import typecheck - -typecheck.enable_wrapping(enabled=False) from nemo.collections import nlp as nemo_nlp from nemo.collections.nlp.models import IntentSlotClassificationModel diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py index 2569d708e235..67f016b0c2af 100644 --- a/tests/collections/tts/test_tts_exportables.py +++ b/tests/collections/tts/test_tts_exportables.py @@ -18,10 +18,6 
@@ import torch from omegaconf import OmegaConf -from nemo.core.classes import typecheck - -typecheck.enable_wrapping(enabled=False) - from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel from nemo.utils.app_state import AppState From 7907435612953524f3534fbbda8143bf32d01d54 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 16 May 2024 17:26:24 -0700 Subject: [PATCH 04/12] External weights behaviour fixed Signed-off-by: Boris Fomitchev --- nemo/core/classes/exportable.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index fe6cbce5bcfa..216742ad05b5 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -202,7 +202,7 @@ def _export( check_trace_input = [input_example] else: check_trace_input = check_trace - jitted_model = self + if format == ExportFormat.TORCHSCRIPT: jitted_model = torch.jit.trace_module( self, @@ -245,10 +245,10 @@ def _export( # https://github.com/pytorch/pytorch/issues/126339 with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None): - print("Running export.export, dynamic shapes:\n", dynamic_shapes) + logging.info("Running export.export, dynamic shapes:{dynamo_export}\n") ex_model = torch.export.export( - jitted_model, + self, tuple(input_list), kwargs=input_dict, dynamic_shapes=dynamic_shapes, @@ -256,19 +256,31 @@ def _export( ) ex_model = ex_model.run_decompositions() - print("Running torch.onnx.dynamo_export ...") - options = torch.onnx.ExportOptions(dynamic_shapes=True, op_level_debug=True) - ex_module = ex_model.module() - ex = torch.onnx.dynamo_export(ex_module, *input_list, **input_dict, export_options=options) - ex.save(output) # , model_state=ex_module.state_dict()) + # We have to use different types of arguments for dynamo_export to achieve + # same external weights behaviour as onnx.export : + # 
https://github.com/pytorch/pytorch/issues/126479 + # https://github.com/pytorch/pytorch/issues/126269 + mem_params = sum([param.nelement() * param.element_size() for param in self.parameters()]) + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) + mem = mem_params + mem_bufs + + if mem > 2 * 1000 * 1000 * 1000: + model_state = ex_model.state_dict + else: + model_state = None + ex_model = ex_model.module() + ex = torch.onnx.dynamo_export(ex_model, *input_list, **input_dict, export_options=options) + ex.save(output, model_state=model_state) + del ex + del ex_model # Rename I/O after save - don't want to risk modifying ex._model_proto rename_onnx_io(output, input_names, output_names) # input_names=None else: torch.onnx.export( - jitted_model, + self, input_example, output, input_names=input_names, From 732c1198f140da3863c88ea69daa63ddf21ba9e2 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 7 Jun 2024 11:56:42 -0700 Subject: [PATCH 05/12] Cleanup Signed-off-by: Boris Fomitchev --- Dockerfile | 29 ++++++--------- .../asr/modules/conformer_encoder.py | 3 +- nemo/collections/asr/modules/conv_asr.py | 3 +- .../megatron/retro_dataset.py | 6 ++- nemo/collections/tts/modules/transformer.py | 20 ++++++---- nemo/core/classes/common.py | 7 ++-- nemo/core/classes/exportable.py | 37 +++++++++---------- nemo/utils/export_utils.py | 4 +- tests/collections/asr/test_asr_exportables.py | 2 +- tests/collections/nlp/test_nlp_exportables.py | 2 + tests/collections/tts/test_tts_exportables.py | 6 +-- 11 files changed, 58 insertions(+), 61 deletions(-) diff --git a/Dockerfile b/Dockerfile index c834fcfbbf48..c27048784244 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-ARG BASE_IMAGE=nvcr.io/nvidia/tritonserver:24.04-py3 +ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.01-py3 # build an image that includes only the nemo dependencies, ensures that dependencies # are included first for optimal caching, and useful for building a development @@ -61,17 +61,20 @@ RUN apt-get update && \ libgts-dev && \ rm -rf /var/lib/apt/lists/* -RUN pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 -RUN pip3 install onnxscript==0.1.0.dev20240430 - WORKDIR /workspace/ # Install megatron core, this can be removed once 0.3 pip package is released # We leave it here in case we need to work off of a specific commit in main RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \ cd Megatron-LM && \ + git checkout 36e9b6bf3d8034b10c9bbd9fc357c2df2bd1515c && \ + git cherry-pick -n e69187bc3679ea5841030a165d587bb48b56ee77 && \ pip install . -RUN pip3 install packaging +# Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771 +RUN git clone https://github.com/NVIDIA/apex.git && \ + cd apex && \ + git checkout f058162b215791b15507bb542f22ccfde49c872d && \ + pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ # Transformer Engine 1.2.0 RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \ @@ -81,12 +84,6 @@ RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \ git submodule init && git submodule update && \ NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install . 
-# Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771 -RUN git clone https://github.com/NVIDIA/apex.git && \ - cd apex && \ - sed -i '178d' setup.py && \ - pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --group_norm --distributed_adam --deprecated_fused_adam" ./ - WORKDIR /tmp/ # uninstall stuff from base container @@ -136,8 +133,6 @@ RUN for f in $(ls requirements*.txt); do pip3 install --disable-pip-version-chec RUN pip install flash-attn # install numba for latest containers RUN pip install numba>=0.57.1 -# install ammo -RUN pip install nvidia-ammo~=0.9.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir # copy nemo source into a scratch image FROM scratch as nemo-src @@ -155,13 +150,11 @@ RUN /usr/bin/test -n "$NEMO_VERSION" && \ # Install NeMo RUN --mount=from=nemo-src,target=/tmp/nemo,rw cd /tmp/nemo && pip install ".[all]" -RUN apt-get install -y python3 -RUN alias python=python3 # Check install -RUN python3 -c "import nemo.collections.nlp as nemo_nlp" && \ - python3 -c "import nemo.collections.tts as nemo_tts" && \ - python3 -c "import nemo_text_processing.text_normalization as text_normalization" +RUN python -c "import nemo.collections.nlp as nemo_nlp" && \ + python -c "import nemo.collections.tts as nemo_tts" && \ + python -c "import nemo_text_processing.text_normalization as text_normalization" # copy scripts/examples/tests into container for end user diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index b9642b3ea5dc..b31aecdc880a 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -495,6 +495,7 @@ def streaming_post_process(self, rets, keep_all_outputs=True): def forward( self, audio_signal, length, cache_last_channel=None, cache_last_time=None, 
cache_last_channel_len=None ): + self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device) return self.forward_internal( audio_signal, length, @@ -506,8 +507,6 @@ def forward( def forward_internal( self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None ): - self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device) - if length is None: length = audio_signal.new_full( (audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py index 25348dae95f3..03b94ae0b209 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -876,8 +876,7 @@ def forward(self, encoder_output, length=None): embs = [] for layer in self.emb_layers: - emb = layer[: self.emb_id](pool) - pool = layer(pool) + pool, emb = layer(pool), layer[: self.emb_id](pool) embs.append(emb) pool = pool.squeeze(-1) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py index 3cec32760328..0f8d3410398d 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py @@ -40,13 +40,15 @@ MultiSplitGPTDatasetConfig, ) from megatron.core.datasets.retro.query.retro_dataset import get_retro_datasets + from megatron.core.datasets.utils import get_blend_from_list from megatron.core.models.retro import RetroConfig from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids HAVE_MEGATRON_CORE = True -except (ImportError, ModuleNotFoundError) as e: +except (ImportError, ModuleNotFoundError): + HAVE_MEGATRON_CORE = False @@ -189,7 +191,7 @@ def is_dataset_built_on_rank(): data_config = MultiSplitGPTDatasetConfig( 
random_seed=cfg.seed, sequence_length=cfg.data.seq_length, - blend=cfg.data.data_prefix, + blend=get_blend_from_list(cfg.data.data_prefix), split=cfg.data.splits_string, split_preprocessing=cfg.data.retro_data.retro_split_preprocessing, path_to_cache=None, diff --git a/nemo/collections/tts/modules/transformer.py b/nemo/collections/tts/modules/transformer.py index 728b583919ff..2243d7d1c317 100644 --- a/nemo/collections/tts/modules/transformer.py +++ b/nemo/collections/tts/modules/transformer.py @@ -125,13 +125,17 @@ def _forward(self, inp, attn_mask=None, conditioning=None): head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2) - head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head) - head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head) - head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head) + s0 = inp.size(0) + s1 = inp.size(1) + s2 = s0 * n_head - q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head) - k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head) - v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head) + head_q = head_q.view(s0, s1, n_head, d_head) + head_k = head_k.view(s0, s1, n_head, d_head) + head_v = head_v.view(s0, s1, n_head, d_head) + + q = head_q.permute(2, 0, 1, 3).reshape(s2, s1, d_head) + k = head_k.permute(2, 0, 1, 3).reshape(s2, s1, d_head) + v = head_v.permute(2, 0, 1, 3).reshape(s2, s1, d_head) attn_score = torch.bmm(q, k.transpose(1, 2)) attn_score.mul_(self.scale) @@ -145,8 +149,8 @@ def _forward(self, inp, attn_mask=None, conditioning=None): attn_prob = self.dropatt(attn_prob) attn_vec = torch.bmm(attn_prob, v) - attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head) - attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), n_head * d_head) + attn_vec = attn_vec.view(n_head, s0, s1, d_head) + attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(s0, s1, n_head * d_head) # linear projection attn_out = self.o_net(attn_vec) 
diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index fe7f040287cc..34cb680db37b 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -1005,7 +1005,7 @@ def __init__( self.ignore_collections = ignore_collections def __call__(self, wrapped): - return self.unwrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped) + return self.wrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped) def unwrapped_call(self, wrapped): return wrapped @@ -1128,7 +1128,6 @@ def disable_semantic_checks(): def enable_wrapping(enabled: bool = True): typecheck.set_typecheck_enabled(enabled) if enabled: - typecheck.__call__.__code__ = nemo.core.classes.common.typecheck.wrapped_call.__code__ + typecheck.__call__ = nemo.core.classes.common.typecheck.wrapped_call else: - typecheck.__call__.__code__ = nemo.core.classes.common.typecheck.unwrapped_call.__code__ - print(typecheck.__call__) + typecheck.__call__ = nemo.core.classes.common.typecheck.unwrapped_call diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 216742ad05b5..380ee819d5f9 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -14,7 +14,6 @@ from abc import ABC from typing import Dict, List, Optional, Union -import onnx import torch from pytorch_lightning.core.module import _jit_is_scripting @@ -125,6 +124,7 @@ def export( check_tolerance=check_tolerance, export_modules_as_functions=export_modules_as_functions, keep_initializers_as_inputs=keep_initializers_as_inputs, + use_dynamo=use_dynamo, ) # Propagate input example (default scenario, may need to be overriden) if input_example is not None: @@ -166,7 +166,7 @@ def _export( # Pytorch's default opset version is too low, using reasonable latest one if onnx_opset_version is None: - onnx_opset_version = 16 + onnx_opset_version = 17 try: # Disable typechecks @@ -225,15 +225,15 @@ def _export( dynamic_axes = 
get_dynamic_axes(self.input_module.input_types_for_export, input_names) if use_dynamo: dynamic_shapes = {} - batch = torch.export.Dim("batch", max=128) + batch = torch.export.Dim("batch") for name, dims in dynamic_axes.items(): ds = {} for d in dims: if d == 0: ds[d] = batch - # this currently fails, https://github.com/pytorch/pytorch/issues/126127 - # else: - # ds[d] = torch.export.Dim(name + '__' + str(d)) + # this currently has issues: https://github.com/pytorch/pytorch/issues/126127 + else: + ds[d] = torch.export.Dim(name + '__' + str(d)) dynamic_shapes[name] = ds else: dynamic_shapes = dynamic_axes @@ -245,18 +245,8 @@ def _export( # https://github.com/pytorch/pytorch/issues/126339 with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None): - logging.info("Running export.export, dynamic shapes:{dynamo_export}\n") + logging.info(f"Running export.export, dynamic shapes:{dynamic_shapes}\n") - ex_model = torch.export.export( - self, - tuple(input_list), - kwargs=input_dict, - dynamic_shapes=dynamic_shapes, - strict=False, - ) - ex_model = ex_model.run_decompositions() - - options = torch.onnx.ExportOptions(dynamic_shapes=True, op_level_debug=True) # We have to use different types of arguments for dynamo_export to achieve # same external weights behaviour as onnx.export : # https://github.com/pytorch/pytorch/issues/126479 @@ -266,10 +256,20 @@ def _export( mem = mem_params + mem_bufs if mem > 2 * 1000 * 1000 * 1000: + ex_model = torch.export.export( + self, + tuple(input_list), + kwargs=input_dict, + dynamic_shapes=dynamic_shapes, + strict=False, + ) + ex_model = ex_model.run_decompositions() model_state = ex_model.state_dict else: model_state = None - ex_model = ex_model.module() + ex_model = self + + options = torch.onnx.ExportOptions(dynamic_shapes=True, op_level_debug=True) ex = torch.onnx.dynamo_export(ex_model, *input_list, **input_dict, export_options=options) ex.save(output, model_state=model_state) @@ -277,7 +277,6 @@ def _export( del 
ex_model # Rename I/O after save - don't want to risk modifying ex._model_proto rename_onnx_io(output, input_names, output_names) - # input_names=None else: torch.onnx.export( self, diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index eda7abd9fe49..c2da09101523 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -500,8 +500,8 @@ def rename_onnx_io(output, input_names, output_names): if n.output[out] in rename_map: n.output[out] = rename_map[n.output[out]] - for i in range(len(onnx_model.graph.input)): + for i in range(len(input_names)): onnx_model.graph.input[i].name = input_names[i] - for i in range(len(onnx_model.graph.output)): + for i in range(len(output_names)): onnx_model.graph.output[i].name = output_names[i] onnx.save(onnx_model, output) diff --git a/tests/collections/asr/test_asr_exportables.py b/tests/collections/asr/test_asr_exportables.py index 9377f49aa1b6..86bcacab86db 100644 --- a/tests/collections/asr/test_asr_exportables.py +++ b/tests/collections/asr/test_asr_exportables.py @@ -144,7 +144,7 @@ def test_EncDecRNNTModel_export_to_onnx(self, citrinet_rnnt_model): with tempfile.TemporaryDirectory() as tmpdir: fn = 'citri_rnnt.onnx' filename = os.path.join(tmpdir, fn) - files, descr = model.export(output=filename, dynamic_axes={}, verbose=False) + files, descr = model.export(output=filename, verbose=False) encoder_filename = os.path.join(tmpdir, 'encoder-' + fn) assert files[0] == encoder_filename diff --git a/tests/collections/nlp/test_nlp_exportables.py b/tests/collections/nlp/test_nlp_exportables.py index c0b97caea4ed..f533c4a36dfd 100644 --- a/tests/collections/nlp/test_nlp_exportables.py +++ b/tests/collections/nlp/test_nlp_exportables.py @@ -20,7 +20,9 @@ import torch import wget from omegaconf import DictConfig, OmegaConf +from nemo.core.classes import typecheck +typecheck.enable_wrapping(enabled=False) from nemo.collections import nlp as nemo_nlp from nemo.collections.nlp.models import 
IntentSlotClassificationModel from nemo.collections.nlp.modules.common import ( diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py index 67f016b0c2af..68c9a55e1f8a 100644 --- a/tests/collections/tts/test_tts_exportables.py +++ b/tests/collections/tts/test_tts_exportables.py @@ -26,7 +26,7 @@ def fastpitch_model(): model = FastPitchModel.from_pretrained(model_name="tts_en_fastpitch") model.export_config['enable_volume'] = True - model.export_config['enable_ragged_batches'] = True + # model.export_config['enable_ragged_batches'] = True return model @@ -65,7 +65,7 @@ def test_FastPitchModel_export_to_onnx(self, fastpitch_model): model = fastpitch_model.cuda() with tempfile.TemporaryDirectory() as tmpdir: filename = os.path.join(tmpdir, 'fp.onnx') - model.export(output=filename, verbose=True, onnx_opset_version=14, check_trace=True) + model.export(output=filename, verbose=True, onnx_opset_version=14, check_trace=True, use_dynamo=True) @pytest.mark.with_downloads() @pytest.mark.run_only_on('GPU') @@ -75,7 +75,7 @@ def test_HifiGanModel_export_to_onnx(self, hifigan_model): assert hifigan_model.generator is not None with tempfile.TemporaryDirectory() as tmpdir: filename = os.path.join(tmpdir, 'hfg.onnx') - model.export(output=filename, verbose=True, check_trace=True) + model.export(output=filename, use_dynamo=True, verbose=True, check_trace=True) @pytest.mark.pleasefixme @pytest.mark.run_only_on('GPU') From dd21b74c3bf08d9d62a90454dfeb78f7194503ae Mon Sep 17 00:00:00 2001 From: borisfom Date: Thu, 13 Jun 2024 05:12:41 +0000 Subject: [PATCH 06/12] Apply isort and black reformatting Signed-off-by: borisfom --- nemo/collections/asr/models/asr_model.py | 2 +- nemo/collections/asr/parts/preprocessing/features.py | 8 ++++---- nemo/collections/asr/parts/submodules/jasper.py | 4 ++-- nemo/collections/tts/modules/transformer.py | 2 +- nemo/utils/cast_utils.py | 2 +- nemo/utils/export_utils.py | 4 ++-- 6 files changed, 11 
insertions(+), 11 deletions(-) diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py index 4f8e82293d48..24e300aff112 100644 --- a/nemo/collections/asr/models/asr_model.py +++ b/nemo/collections/asr/models/asr_model.py @@ -240,7 +240,7 @@ def output_names(self): if getattr(self.input_module, 'export_cache_support', False): in_types = self.input_module.output_types otypes = {n: t for (n, t) in list(otypes.items())[:1]} - for (n, t) in list(in_types.items())[1:]: + for n, t in list(in_types.items())[1:]: otypes[n] = t return get_io_names(otypes, self.disabled_deployment_output_names) diff --git a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py index 51fc6c2418f7..d70737b5135b 100644 --- a/nemo/collections/asr/parts/preprocessing/features.py +++ b/nemo/collections/asr/parts/preprocessing/features.py @@ -131,7 +131,7 @@ def clean_spectrogram_batch(spectrogram: torch.Tensor, spectrogram_len: torch.Te def splice_frames(x, frame_splicing): - """ Stacks frames together across feature dim + """Stacks frames together across feature dim input is batch_size, feature_dim, num_frames output is batch_size, feature_dim*frame_splicing, num_frames @@ -261,7 +261,7 @@ def __init__( highfreq=None, log=True, log_zero_guard_type="add", - log_zero_guard_value=2 ** -24, + log_zero_guard_value=2**-24, dither=CONSTANT, pad_to=16, max_duration=16.7, @@ -511,7 +511,7 @@ def __init__( highfreq: Optional[float] = None, log: bool = True, log_zero_guard_type: str = "add", - log_zero_guard_value: Union[float, str] = 2 ** -24, + log_zero_guard_value: Union[float, str] = 2**-24, dither: float = 1e-5, window: str = "hann", pad_to: int = 0, @@ -582,7 +582,7 @@ def __init__( @property def filter_banks(self): - """ Matches the analogous class """ + """Matches the analogous class""" return self._mel_spec_extractor.mel_scale.fb def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float: diff 
--git a/nemo/collections/asr/parts/submodules/jasper.py b/nemo/collections/asr/parts/submodules/jasper.py index c2beb3918ead..78f81ee555bc 100644 --- a/nemo/collections/asr/parts/submodules/jasper.py +++ b/nemo/collections/asr/parts/submodules/jasper.py @@ -510,8 +510,8 @@ def _se_pool_step(self, x, mask): return y def set_max_len(self, max_len, seq_range=None): - """ Sets maximum input length. - Pre-calculates internal seq_range mask. + """Sets maximum input length. + Pre-calculates internal seq_range mask. """ self.max_len = max_len if seq_range is None: diff --git a/nemo/collections/tts/modules/transformer.py b/nemo/collections/tts/modules/transformer.py index 2243d7d1c317..25c177d221cc 100644 --- a/nemo/collections/tts/modules/transformer.py +++ b/nemo/collections/tts/modules/transformer.py @@ -102,7 +102,7 @@ def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1, pre_lnorm=Fals self.n_head = n_head self.d_model = d_model self.d_head = d_head - self.scale = 1 / (d_head ** 0.5) + self.scale = 1 / (d_head**0.5) self.pre_lnorm = pre_lnorm self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head) diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py index d59189cc912e..a7960be4cc4d 100644 --- a/nemo/utils/cast_utils.py +++ b/nemo/utils/cast_utils.py @@ -95,7 +95,7 @@ def forward(self, *args): @contextmanager def monkeypatched(object, name, patch): - """ Temporarily monkeypatches an object. 
""" + """Temporarily monkeypatches an object.""" pre_patched_value = getattr(object, name) setattr(object, name, patch) yield object diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index c2da09101523..c44530944051 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -383,7 +383,7 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]: def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]: """ - Generic function generator to replace BaseT module with DestT wrapper. + Generic function generator to replace BaseT module with DestT wrapper. Args: BaseT : module type to replace DestT : destination module type @@ -450,7 +450,7 @@ def script_module(m: nn.Module): def replace_for_export(model: nn.Module) -> nn.Module: """ - Top-level function to replace 'default set' of modules in model, called from _prepare_for_export. + Top-level function to replace 'default set' of modules in model, called from _prepare_for_export. NOTE: This occurs in place, if you want to preserve model then make sure to copy it first. 
Args: model : top level module From 6c299d6e078d0d2927aae868d1d6e2c22fbb5d8a Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 12 Jun 2024 22:14:58 -0700 Subject: [PATCH 07/12] print cleaned up Signed-off-by: Boris Fomitchev --- .../data/language_modeling/megatron/retro_dataset.py | 3 +-- .../multimodal/Multimodal Data Preparation.ipynb | 12 ++++++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py index 5d227ff23342..c63c6faaddb8 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py @@ -47,8 +47,7 @@ HAVE_MEGATRON_CORE = True -except (ImportError, ModuleNotFoundError) as e: - print(e) +except (ImportError, ModuleNotFoundError): HAVE_MEGATRON_CORE = False diff --git a/tutorials/multimodal/Multimodal Data Preparation.ipynb b/tutorials/multimodal/Multimodal Data Preparation.ipynb index b3a38b8b5ec2..fb7bdee1402f 100644 --- a/tutorials/multimodal/Multimodal Data Preparation.ipynb +++ b/tutorials/multimodal/Multimodal Data Preparation.ipynb @@ -14,7 +14,8 @@ ], "metadata": { "collapsed": false - } + }, + "id": "88adf24c9f52084f" }, { "cell_type": "code", @@ -56,7 +57,8 @@ ], "metadata": { "collapsed": false - } + }, + "id": "bb0c8d61cdb92704" }, { "attachments": {}, @@ -207,7 +209,8 @@ }, "source": [ "Note: In this dummy dataset, you will likely see a success rate of 1.000 (no failures). 
However, for read datasets, the success rate will always be much less than 1.000" - ] + ], + "id": "eaffa123548d6a5e" }, { "attachments": {}, @@ -649,7 +652,8 @@ "\n", "After this, you can proceed with Stage 3 of the tutorial.\n", "Note: if you can use a script to create folders with exactly `tar_chunk_size` (1000 in the tutorial) image-text pairs, and create multiple tarfiles each with `tar_chunk_size` pairs of data, then you can skip Stage 3 and proceed with Stage 4 of the tutorial." - ] + ], + "id": "217dacb92b870798" } ], "metadata": { From f425d8a3ef0c8546fa5285930afb73fd65d7953d Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 13 Jun 2024 15:26:47 -0700 Subject: [PATCH 08/12] Added overloadable dynamic_shapes_for_export Signed-off-by: Boris Fomitchev --- nemo/core/classes/exportable.py | 21 ++++----------- nemo/core/utils/neural_type_utils.py | 39 ++++++++++++++++++---------- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index a4803d368e80..7fa950f2c2c0 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -222,21 +222,7 @@ def _export( elif format == ExportFormat.ONNX: # dynamic axis is a mapping from input/output_name => list of "dynamic" indices if dynamic_axes is None: - dynamic_axes = get_dynamic_axes(self.input_module.input_types_for_export, input_names) - if use_dynamo: - dynamic_shapes = {} - batch = torch.export.Dim("batch") - for name, dims in dynamic_axes.items(): - ds = {} - for d in dims: - if d == 0: - ds[d] = batch - # this currently has issues: https://github.com/pytorch/pytorch/issues/126127 - else: - ds[d] = torch.export.Dim(name + '__' + str(d)) - dynamic_shapes[name] = ds - else: - dynamic_shapes = dynamic_axes + dynamic_axes = self.dynamic_shapes_for_export(use_dynamo) if use_dynamo: import onnxscript @@ -260,7 +246,7 @@ def _export( self, tuple(input_list), kwargs=input_dict, - dynamic_shapes=dynamic_shapes, + 
dynamic_shapes=dynamic_axes, strict=False, ) ex_model = ex_model.run_decompositions() @@ -348,6 +334,9 @@ def input_types_for_export(self) -> Optional[Dict[str, NeuralType]]: def output_types_for_export(self): return self.output_types + def dynamic_shapes_for_export(self, use_dynamo=False): + return get_dynamic_axes(self.input_module.input_types_for_export, self.input_names, use_dynamo) + def get_export_subnet(self, subnet=None): """ Returns Exportable subnet model/module to export diff --git a/nemo/core/utils/neural_type_utils.py b/nemo/core/utils/neural_type_utils.py index 98ae442b9aa7..c540a8b1912b 100644 --- a/nemo/core/utils/neural_type_utils.py +++ b/nemo/core/utils/neural_type_utils.py @@ -30,19 +30,19 @@ def get_io_names(types: Optional[Dict[str, NeuralType]], disabled_names: List[st def extract_dynamic_axes(name: str, ntype: NeuralType): """ - This method will extract BATCH and TIME dimension ids from each provided input/output name argument. - - For example, if module/model accepts argument named "input_signal" with type corresponding to [Batch, Time, Dim] - shape, then the returned result should contain "input_signal" -> [0, 1] because Batch and Time are dynamic axes - as they can change from call to call during inference. - - Args: - name: Name of input or output parameter - ntype: Corresponding Neural Type - - Returns: + This method will extract BATCH and TIME dimension ids from each provided input/output name argument. - """ + For example, if module/model accepts argument named "input_signal" with type corresponding to [Batch, Time, Dim] + shape, then the returned result should contain "input_signal" -> [0, 1] because Batch and Time are dynamic axes + as they can change from call to call during inference. 
+ + Args: + name: Name of input or output parameter + ntype: Corresponding Neural Type + + Returns: + + """ def unpack_nested_neural_type(neural_type): if type(neural_type) in (list, tuple): @@ -60,10 +60,23 @@ def unpack_nested_neural_type(neural_type): return dynamic_axes -def get_dynamic_axes(types, names): +def get_dynamic_axes(types, names, use_dynamo=False): dynamic_axes = defaultdict(list) if names is not None: for name in names: if name in types: dynamic_axes.update(extract_dynamic_axes(name, types[name])) + if use_dynamo: + dynamic_shapes = {} + batch = torch.export.Dim("batch") + for name, dims in dynamic_axes.items(): + ds = {} + for d in dims: + if d == 0: + ds[d] = batch + # this currently has issues: https://github.com/pytorch/pytorch/issues/126127 + else: + ds[d] = torch.export.Dim(name + '__' + str(d)) + dynamic_shapes[name] = ds + dynamic_axes = dynamic_shapes return dynamic_axes From c7a5e84b45af213a4af7f914f1008cd18e667c78 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Fri, 14 Jun 2024 22:38:20 -0700 Subject: [PATCH 09/12] Addressing code review Signed-off-by: Boris Fomitchev --- .../language_modeling/megatron/retro_dataset.py | 1 + nemo/core/classes/exportable.py | 5 +++-- nemo/core/utils/neural_type_utils.py | 2 +- tests/collections/nlp/test_nlp_exportables.py | 13 +++++++++---- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py index c63c6faaddb8..7d604c0b51bc 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py @@ -48,6 +48,7 @@ HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): + HAVE_MEGATRON_CORE = False diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 7fa950f2c2c0..e6a7344aa293 100644 --- a/nemo/core/classes/exportable.py 
+++ b/nemo/core/classes/exportable.py @@ -101,6 +101,7 @@ def export( ONNX specific. keep_initializers_as_inputs (bool): If True, will keep the model's initializers as inputs in the onnx graph. This is ONNX specific. + use_dynamo (bool): If True, use onnx.dynamo_export() instead of onnx.export(). This is ONNX specific. Returns: A tuple of two outputs. @@ -179,7 +180,7 @@ def _export( with torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True), _jit_is_scripting(): if input_example is None: - input_example = self.input_module.input_example(max_batch=2) + input_example = self.input_module.input_example() # Remove i/o examples from args we propagate to enclosed Exportables my_args.pop('output') @@ -231,7 +232,7 @@ def _export( # https://github.com/pytorch/pytorch/issues/126339 with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None): - logging.info(f"Running export.export, dynamic shapes:{dynamic_shapes}\n") + logging.info(f"Running export.export, dynamic shapes:{dynamic_axes}\n") # We have to use different types of arguments for dynamo_export to achieve # same external weights behaviour as onnx.export : diff --git a/nemo/core/utils/neural_type_utils.py b/nemo/core/utils/neural_type_utils.py index c540a8b1912b..5a634dad3d57 100644 --- a/nemo/core/utils/neural_type_utils.py +++ b/nemo/core/utils/neural_type_utils.py @@ -14,7 +14,7 @@ from collections import defaultdict from typing import Dict, List, Optional - +import torch from nemo.core.neural_types import AxisKind, NeuralType diff --git a/tests/collections/nlp/test_nlp_exportables.py b/tests/collections/nlp/test_nlp_exportables.py index 3e44fd4dc2a8..dbd5b3ac4427 100644 --- a/tests/collections/nlp/test_nlp_exportables.py +++ b/tests/collections/nlp/test_nlp_exportables.py @@ -41,7 +41,7 @@ def classifier_export(obj): with tempfile.TemporaryDirectory() as tmpdir: filename = os.path.join(tmpdir, obj.__class__.__name__ + '.onnx') obj = obj.cuda() - obj.export(output=filename, 
use_dynamo=True) + obj.export(output=filename, use_dynamo=True, check_trace=True) class TestExportableClassifiers: @@ -181,7 +181,8 @@ def test_IntentSlotClassificationModel_export_to_onnx(self, dummy_data): trainer = pl.Trainer(**config.trainer) model = IntentSlotClassificationModel(config.model, trainer=trainer) filename = os.path.join(tmpdir, 'isc.onnx') - model.export(output=filename, check_trace=True) + model.export(output=filename, check_trace=True, use_dynamo=False) + model.export(output=filename, check_trace=True, use_dynamo=True) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed assert onnx_model.graph.input[0].name == 'input_ids' @@ -197,7 +198,8 @@ def test_TokenClassificationModel_export_to_onnx(self): model = nemo_nlp.models.TokenClassificationModel.from_pretrained(model_name="ner_en_bert") with tempfile.TemporaryDirectory() as tmpdir: filename = os.path.join(tmpdir, 'ner.onnx') - model.export(output=filename, check_trace=True) + model.export(output=filename, check_trace=True, use_dynamo=False) + model.export(output=filename, check_trace=True, use_dynamo=True) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed assert onnx_model.graph.input[0].name == 'input_ids' @@ -212,7 +214,9 @@ def test_PunctuationCapitalizationModel_export_to_onnx(self): model = nemo_nlp.models.PunctuationCapitalizationModel.from_pretrained(model_name="punctuation_en_distilbert") with tempfile.TemporaryDirectory() as tmpdir: filename = os.path.join(tmpdir, 'puncap.onnx') - model.export(output=filename, check_trace=True) + model.export(output=filename, check_trace=True, use_dynamo=False) + # Unsupported FX nodes: {'call_function': ['aten.detach_.default']}. 
+ # model.export(output=filename, check_trace=True, use_dynamo=True) onnx_model = onnx.load(filename) onnx.checker.check_model(onnx_model, full_check=True) # throws when failed assert onnx_model.graph.input[0].name == 'input_ids' @@ -227,6 +231,7 @@ def test_QAModel_export_to_onnx(self): model = nemo_nlp.models.QAModel.from_pretrained(model_name="qa_squadv2.0_bertbase") with tempfile.TemporaryDirectory() as tmpdir: filename = os.path.join(tmpdir, 'qa.onnx') + model.export(output=filename, check_trace=True, use_dynamo=False) model.export(output=filename, check_trace=True, use_dynamo=True) onnx_model = onnx.load(filename) assert onnx_model.graph.input[0].name == 'input_ids' From 21a5882444875a995299f4db1179e74a73bc3fb8 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Tue, 18 Jun 2024 17:13:23 -0700 Subject: [PATCH 10/12] Fixing CI issues Signed-off-by: Boris Fomitchev --- Dockerfile.ci | 1 + nemo/collections/asr/models/msdd_models.py | 70 ++++++++++++---------- nemo/core/classes/common.py | 7 +-- nemo/core/classes/exportable.py | 5 -- 4 files changed, 40 insertions(+), 43 deletions(-) diff --git a/Dockerfile.ci b/Dockerfile.ci index 18188f7be45f..a1bc61dece62 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -48,6 +48,7 @@ pip install --no-cache-dir --no-build-isolation --extra-index-url https://pypi.n "nvidia-modelopt[torch]~=${MODELOPT_VERSION}" \ "apex @ git+https://github.com/NVIDIA/apex.git@${APEX_TAG}" \ "llama-index==0.10.43" \ +"onnxscript @ git+https://github.com/microsoft/onnxscript" \ -r tools/ctc_segmentation/requirements.txt \ ".[all]" diff --git a/nemo/collections/asr/models/msdd_models.py b/nemo/collections/asr/models/msdd_models.py index 01926eb4ae79..60aae8d1a4b1 100644 --- a/nemo/collections/asr/models/msdd_models.py +++ b/nemo/collections/asr/models/msdd_models.py @@ -163,8 +163,7 @@ def add_speaker_model_config(self, cfg): del cfg.speaker_model_cfg.validation_ds def _init_segmentation_info(self): - """Initialize segmentation settings: window, 
shift and multiscale weights. - """ + """Initialize segmentation settings: window, shift and multiscale weights.""" self._diarizer_params = self.cfg_msdd_model.diarizer self.multiscale_args_dict = parse_scale_configs( self._diarizer_params.speaker_embeddings.parameters.window_length_in_sec, @@ -275,10 +274,14 @@ def __setup_dataloader_from_config_infer( ) def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): - self._train_dl = self.__setup_dataloader_from_config(config=train_data_config,) + self._train_dl = self.__setup_dataloader_from_config( + config=train_data_config, + ) def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]): - self._validation_dl = self.__setup_dataloader_from_config(config=val_data_layer_config,) + self._validation_dl = self.__setup_dataloader_from_config( + config=val_data_layer_config, + ) def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): if self.pairwise_infer: @@ -338,32 +341,32 @@ def get_ms_emb_seq( Merged embeddings without zero-padding in the batch. See `ms_seg_counts` for details. Shape: (Total number of segments in the batch, emb_dim) scale_mapping (Tensor): - The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale - segment index which has the closest center distance with (n+1)-th segment in the base scale. - Example: - scale_mapping_argmat[2][101] = 85 - In the above example, it means that 86-th segment in the 3rd scale (python index is 2) is mapped with - 102-th segment in the base scale. Thus, the longer segments bound to have more repeating numbers since - multiple base scale segments (since the base scale has the shortest length) fall into the range of the - longer segments. At the same time, each row contains N numbers of indices where N is number of - segments in the base-scale (i.e., the finest scale). 
+ The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale + segment index which has the closest center distance with (n+1)-th segment in the base scale. + Example: + scale_mapping_argmat[2][101] = 85 + In the above example, it means that 86-th segment in the 3rd scale (python index is 2) is mapped with + 102-th segment in the base scale. Thus, the longer segments bound to have more repeating numbers since + multiple base scale segments (since the base scale has the shortest length) fall into the range of the + longer segments. At the same time, each row contains N numbers of indices where N is number of + segments in the base-scale (i.e., the finest scale). Shape: (batch_size, scale_n, self.diar_window_length) ms_seg_counts (Tensor): Cumulative sum of the number of segments in each scale. This information is needed to reconstruct the multi-scale input matrix during forward propagating. - Example: `batch_size=3, scale_n=6, emb_dim=192` - ms_seg_counts = - [[8, 9, 12, 16, 25, 51], - [11, 13, 14, 17, 25, 51], - [ 9, 9, 11, 16, 23, 50]] + Example: `batch_size=3, scale_n=6, emb_dim=192` + ms_seg_counts = + [[8, 9, 12, 16, 25, 51], + [11, 13, 14, 17, 25, 51], + [ 9, 9, 11, 16, 23, 50]] - In this function, `ms_seg_counts` is used to get the actual length of each embedding sequence without - zero-padding. + In this function, `ms_seg_counts` is used to get the actual length of each embedding sequence without + zero-padding. Returns: ms_emb_seq (Tensor): - Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less repeated, + Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less repeated, while shorter scales are more frequently repeated following the scale mapping tensor. 
""" scale_n, batch_size = scale_mapping[0].shape[0], scale_mapping.shape[0] @@ -409,9 +412,9 @@ def get_cluster_avg_embs_model( [ 9, 9, 11, 16, 23, 50] ] - Counts of merged segments: (121, 131, 118) - embs has shape of (370, 192) - clus_label_index has shape of (3, 131) + Counts of merged segments: (121, 131, 118) + embs has shape of (370, 192) + clus_label_index has shape of (3, 131) Shape: (batch_size, scale_n) @@ -553,7 +556,7 @@ def forward( with torch.no_grad(): self.msdd._speaker_model.eval() logits, embs_d = self.msdd._speaker_model.forward_for_export( - processed_signal=audio_signal[detach_ids[1]], processed_signal_len=audio_signal_len[detach_ids[1]] + audio_signal=audio_signal[detach_ids[1]], length=audio_signal_len[detach_ids[1]] ) embs = torch.zeros(audio_signal.shape[0], embs_d.shape[1]).to(embs_d.device) embs[detach_ids[1], :] = embs_d.detach() @@ -854,9 +857,9 @@ def run_clustering_diarizer(self, manifest_filepath: str, emb_dir: str): os.makedirs(self.out_rttm_dir, exist_ok=True) self.clus_diar_model._cluster_params = self.cfg_diar_infer.diarizer.clustering.parameters - self.clus_diar_model.multiscale_args_dict[ - "multiscale_weights" - ] = self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.multiscale_weights + self.clus_diar_model.multiscale_args_dict["multiscale_weights"] = ( + self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.multiscale_weights + ) self.clus_diar_model._diarizer_params.speaker_embeddings.parameters = ( self.cfg_diar_infer.diarizer.speaker_embeddings.parameters ) @@ -1076,7 +1079,6 @@ def extract_standalone_speaker_model(self, prefix: str = 'msdd._speaker_model.') return _speaker_model def _init_msdd_model(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig]): - """ Initialized MSDD model with the provided config. Load either from `.nemo` file or `.ckpt` checkpoint files. 
""" @@ -1128,7 +1130,7 @@ def get_pred_mat(self, data_list: List[Union[Tuple[int], List[torch.Tensor]]]) - digit_map = dict(zip(sorted(set(all_tups)), range(n_est_spks))) total_len = max([sess[1].shape[1] for sess in data_list]) sum_pred = torch.zeros(total_len, n_est_spks) - for (_dim_tup, pred_mat) in data_list: + for _dim_tup, pred_mat in data_list: dim_tup = [digit_map[x] for x in _dim_tup] if len(pred_mat.shape) == 3: pred_mat = pred_mat.squeeze(0) @@ -1167,8 +1169,7 @@ def get_integrated_preds_list( return output_list def get_emb_clus_infer(self, cluster_embeddings): - """Assign dictionaries containing the clustering results from the class instance `cluster_embeddings`. - """ + """Assign dictionaries containing the clustering results from the class instance `cluster_embeddings`.""" self.msdd_model.emb_sess_test_dict = cluster_embeddings.emb_sess_test_dict self.msdd_model.clus_test_label_dict = cluster_embeddings.clus_test_label_dict self.msdd_model.emb_seq_test = cluster_embeddings.emb_seq_test @@ -1456,7 +1457,10 @@ def from_pretrained( """ logging.setLevel(logging.INFO if verbose else logging.WARNING) cfg = NeuralDiarizerInferenceConfig.init_config( - diar_model_path=model_name, vad_model_path=vad_model_name, map_location=map_location, verbose=verbose, + diar_model_path=model_name, + vad_model_path=vad_model_name, + map_location=map_location, + verbose=verbose, ) return cls(cfg) diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index 73253bbd0a56..60f842dbfb68 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -1016,16 +1016,13 @@ def __init__( self.ignore_collections = ignore_collections def __call__(self, wrapped): - return self.wrapped_call(wrapped) if is_typecheck_enabled() else self.unwrapped_call(wrapped) + return self.wrapped_call(wrapped) def unwrapped_call(self, wrapped): return wrapped - def wrapped_call(self, wrapped): - return self.decorated_call(wrapped) - 
@wrapt.decorator(enabled=is_typecheck_enabled) - def decorated_call(self, wrapped, instance: Typing, args, kwargs): + def wrapped_call(self, wrapped, instance: Typing, args, kwargs): """ Wrapper method that can be used on any function of a class that implements :class:`~nemo.core.Typing`. By default, it will utilize the `input_types` and `output_types` properties of the class inheriting Typing. diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index e6a7344aa293..d1773cedbaa3 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -225,11 +225,6 @@ def _export( if dynamic_axes is None: dynamic_axes = self.dynamic_shapes_for_export(use_dynamo) if use_dynamo: - import onnxscript - - # https://github.com/microsoft/onnxscript/issues/1544 - onnxscript.optimizer.constant_folding._DEFAULT_CONSTANT_FOLD_SIZE_LIMIT = 1024 * 1024 * 64 - # https://github.com/pytorch/pytorch/issues/126339 with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None): logging.info(f"Running export.export, dynamic shapes:{dynamic_axes}\n") From 52891309ab03f2b6470c4896bbd7fc7b671fd340 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 26 Jun 2024 15:55:35 -0700 Subject: [PATCH 11/12] Fixing CI test failure Signed-off-by: Boris Fomitchev --- nemo/core/classes/common.py | 14 ++++---------- tests/collections/nlp/test_nlp_exportables.py | 2 +- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index 60f842dbfb68..81f8e6b3e14c 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -1015,14 +1015,11 @@ def __init__( self.ignore_collections = ignore_collections - def __call__(self, wrapped): - return self.wrapped_call(wrapped) - def unwrapped_call(self, wrapped): return wrapped @wrapt.decorator(enabled=is_typecheck_enabled) - def wrapped_call(self, wrapped, instance: Typing, args, kwargs): + def __call__(self, wrapped, instance: 
Typing, args, kwargs): """ Wrapper method that can be used on any function of a class that implements :class:`~nemo.core.Typing`. By default, it will utilize the `input_types` and `output_types` properties of the class inheriting Typing. @@ -1133,9 +1130,6 @@ def disable_semantic_checks(): typecheck.set_semantic_check_enabled(enabled=True) @staticmethod - def enable_wrapping(enabled: bool = True): - typecheck.set_typecheck_enabled(enabled) - if enabled: - typecheck.__call__ = nemo.core.classes.common.typecheck.wrapped_call - else: - typecheck.__call__ = nemo.core.classes.common.typecheck.unwrapped_call + def disable_wrapping(): + typecheck.set_typecheck_enabled(False) + typecheck.__call__ = nemo.core.classes.common.typecheck.unwrapped_call diff --git a/tests/collections/nlp/test_nlp_exportables.py b/tests/collections/nlp/test_nlp_exportables.py index dbd5b3ac4427..119093b703de 100644 --- a/tests/collections/nlp/test_nlp_exportables.py +++ b/tests/collections/nlp/test_nlp_exportables.py @@ -25,7 +25,7 @@ # Has to be applied before first import of NeMo from nemo.core.classes import typecheck -typecheck.enable_wrapping(enabled=False) +typecheck.disable_wrapping() from nemo.collections import nlp as nemo_nlp from nemo.collections.nlp.models import IntentSlotClassificationModel From f804d267a3c44f78642ab2b5024d19bfda4f4c2a Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Wed, 26 Jun 2024 17:57:46 -0700 Subject: [PATCH 12/12] Eliminated test cross-contamination Signed-off-by: Boris Fomitchev --- nemo/core/classes/common.py | 14 ++++++++++---- nemo/core/classes/exportable.py | 2 ++ tests/collections/nlp/test_nlp_exportables.py | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py index 81f8e6b3e14c..60f842dbfb68 100644 --- a/nemo/core/classes/common.py +++ b/nemo/core/classes/common.py @@ -1015,11 +1015,14 @@ def __init__( self.ignore_collections = ignore_collections + def __call__(self, 
wrapped): + return self.wrapped_call(wrapped) + def unwrapped_call(self, wrapped): return wrapped @wrapt.decorator(enabled=is_typecheck_enabled) - def __call__(self, wrapped, instance: Typing, args, kwargs): + def wrapped_call(self, wrapped, instance: Typing, args, kwargs): """ Wrapper method that can be used on any function of a class that implements :class:`~nemo.core.Typing`. By default, it will utilize the `input_types` and `output_types` properties of the class inheriting Typing. @@ -1130,6 +1133,9 @@ def disable_semantic_checks(): typecheck.set_semantic_check_enabled(enabled=True) @staticmethod - def disable_wrapping(): - typecheck.set_typecheck_enabled(False) - typecheck.__call__ = nemo.core.classes.common.typecheck.unwrapped_call + def enable_wrapping(enabled: bool = True): + typecheck.set_typecheck_enabled(enabled) + if enabled: + typecheck.__call__ = nemo.core.classes.common.typecheck.wrapped_call + else: + typecheck.__call__ = nemo.core.classes.common.typecheck.unwrapped_call diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index d1773cedbaa3..aab09d42d907 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -225,6 +225,7 @@ def _export( if dynamic_axes is None: dynamic_axes = self.dynamic_shapes_for_export(use_dynamo) if use_dynamo: + typecheck.enable_wrapping(enabled=False) # https://github.com/pytorch/pytorch/issues/126339 with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None): logging.info(f"Running export.export, dynamic shapes:{dynamic_axes}\n") @@ -279,6 +280,7 @@ def _export( else: raise ValueError(f'Encountered unknown export format {format}.') finally: + typecheck.enable_wrapping(enabled=True) typecheck.set_typecheck_enabled(enabled=True) if forward_method: type(self).forward = old_forward_method diff --git a/tests/collections/nlp/test_nlp_exportables.py b/tests/collections/nlp/test_nlp_exportables.py index 119093b703de..dbd5b3ac4427 100644 --- 
a/tests/collections/nlp/test_nlp_exportables.py +++ b/tests/collections/nlp/test_nlp_exportables.py @@ -25,7 +25,7 @@ # Has to be applied before first import of NeMo from nemo.core.classes import typecheck -typecheck.disable_wrapping() +typecheck.enable_wrapping(enabled=False) from nemo.collections import nlp as nemo_nlp from nemo.collections.nlp.models import IntentSlotClassificationModel