diff --git a/Dockerfile.ci b/Dockerfile.ci
index 04ba9df13c7a..6d59d300b26f 100644
--- a/Dockerfile.ci
+++ b/Dockerfile.ci
@@ -48,6 +48,7 @@ pip install --no-cache-dir --no-build-isolation --extra-index-url https://pypi.n
 "nvidia-modelopt[torch]~=${MODELOPT_VERSION}" \
 "apex @ git+https://github.com/NVIDIA/apex.git@${APEX_TAG}" \
 "llama-index==0.10.43" \
+"onnxscript @ git+https://github.com/microsoft/onnxscript" \
 -r tools/ctc_segmentation/requirements.txt \
 ".[all]"
diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py
index 0539f961a1ca..24e300aff112 100644
--- a/nemo/collections/asr/models/asr_model.py
+++ b/nemo/collections/asr/models/asr_model.py
@@ -240,12 +240,12 @@ def output_names(self):
         if getattr(self.input_module, 'export_cache_support', False):
             in_types = self.input_module.output_types
             otypes = {n: t for (n, t) in list(otypes.items())[:1]}
-            for (n, t) in list(in_types.items())[1:]:
+            for n, t in list(in_types.items())[1:]:
                 otypes[n] = t
         return get_io_names(otypes, self.disabled_deployment_output_names)

     def forward_for_export(
-        self, input, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
+        self, audio_signal, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
     ):
         """
         This forward is used when we need to export the model to ONNX format.
@@ -264,12 +264,12 @@ def forward_for_export(
         """
         enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward)
         if cache_last_channel is None:
-            encoder_output = enc_fun(audio_signal=input, length=length)
+            encoder_output = enc_fun(audio_signal=audio_signal, length=length)
             if isinstance(encoder_output, tuple):
                 encoder_output = encoder_output[0]
         else:
             encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun(
-                audio_signal=input,
+                audio_signal=audio_signal,
                 length=length,
                 cache_last_channel=cache_last_channel,
                 cache_last_time=cache_last_time,
diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py
index 071c53417ae2..9de47645d4f3 100644
--- a/nemo/collections/asr/models/label_models.py
+++ b/nemo/collections/asr/models/label_models.py
@@ -333,8 +333,8 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
             "embs": NeuralType(('B', 'D'), AcousticEncodedRepresentation()),
         }

-    def forward_for_export(self, processed_signal, processed_signal_len):
-        encoded, length = self.encoder(audio_signal=processed_signal, length=processed_signal_len)
+    def forward_for_export(self, audio_signal, length):
+        encoded, length = self.encoder(audio_signal=audio_signal, length=length)
         logits, embs = self.decoder(encoder_output=encoded, length=length)
         return logits, embs
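
Context for the renames above: `forward_for_export` now takes `audio_signal`/`length` instead of `input`/`processed_signal`, so the Python parameter names line up with the names the models advertise through `input_names`. A minimal sketch of the underlying ONNX mechanics, with a hypothetical `TinyEncoder` (not NeMo code):

    import torch
    import torch.nn as nn

    class TinyEncoder(nn.Module):
        def forward(self, audio_signal, length):
            # stand-in for a real encoder: (B, D, T) -> (B, D), lengths passed through
            return audio_signal.mean(dim=-1), length

    enc = TinyEncoder()
    args = (torch.randn(2, 80, 100), torch.tensor([100, 80]))

    # input_names labels the traced inputs in positional order; keeping the Python
    # parameter names identical to the exported names avoids mixups when callers
    # pass inputs by keyword.
    torch.onnx.export(
        enc,
        args,
        "tiny_encoder.onnx",
        input_names=["audio_signal", "length"],
        output_names=["encoded", "encoded_len"],
    )
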
diff --git a/nemo/collections/asr/models/msdd_models.py b/nemo/collections/asr/models/msdd_models.py
index 01926eb4ae79..60aae8d1a4b1 100644
--- a/nemo/collections/asr/models/msdd_models.py
+++ b/nemo/collections/asr/models/msdd_models.py
@@ -163,8 +163,7 @@ def add_speaker_model_config(self, cfg):
         del cfg.speaker_model_cfg.validation_ds

     def _init_segmentation_info(self):
-        """Initialize segmentation settings: window, shift and multiscale weights.
-        """
+        """Initialize segmentation settings: window, shift and multiscale weights."""
         self._diarizer_params = self.cfg_msdd_model.diarizer
         self.multiscale_args_dict = parse_scale_configs(
             self._diarizer_params.speaker_embeddings.parameters.window_length_in_sec,
@@ -275,10 +274,14 @@ def __setup_dataloader_from_config_infer(
         )

     def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
-        self._train_dl = self.__setup_dataloader_from_config(config=train_data_config,)
+        self._train_dl = self.__setup_dataloader_from_config(
+            config=train_data_config,
+        )

     def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]):
-        self._validation_dl = self.__setup_dataloader_from_config(config=val_data_layer_config,)
+        self._validation_dl = self.__setup_dataloader_from_config(
+            config=val_data_layer_config,
+        )

     def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
         if self.pairwise_infer:
@@ -338,32 +341,32 @@ def get_ms_emb_seq(
                 Merged embeddings without zero-padding in the batch. See `ms_seg_counts` for details.
                 Shape: (Total number of segments in the batch, emb_dim)
             scale_mapping (Tensor):
-            The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale
-            segment index which has the closest center distance with (n+1)-th segment in the base scale.
-            Example:
-                scale_mapping_argmat[2][101] = 85
-            In the above example, it means that 86-th segment in the 3rd scale (python index is 2) is mapped with
-            102-th segment in the base scale. Thus, the longer segments bound to have more repeating numbers since
-            multiple base scale segments (since the base scale has the shortest length) fall into the range of the
-            longer segments. At the same time, each row contains N numbers of indices where N is number of
-            segments in the base-scale (i.e., the finest scale).
+                The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale
+                segment index which has the closest center distance with (n+1)-th segment in the base scale.
+                Example:
+                    scale_mapping_argmat[2][101] = 85
+                In the above example, it means that 86-th segment in the 3rd scale (python index is 2) is mapped with
+                102-th segment in the base scale. Thus, the longer segments bound to have more repeating numbers since
+                multiple base scale segments (since the base scale has the shortest length) fall into the range of the
+                longer segments. At the same time, each row contains N numbers of indices where N is number of
+                segments in the base-scale (i.e., the finest scale).
                 Shape: (batch_size, scale_n, self.diar_window_length)
             ms_seg_counts (Tensor):
                 Cumulative sum of the number of segments in each scale. This information is needed to reconstruct
                 the multi-scale input matrix during forward propagating.
-            Example: `batch_size=3, scale_n=6, emb_dim=192`
-                ms_seg_counts =
-                    [[8,  9, 12, 16, 25, 51],
-                     [11, 13, 14, 17, 25, 51],
-                     [ 9,  9, 11, 16, 23, 50]]
+                Example: `batch_size=3, scale_n=6, emb_dim=192`
+                    ms_seg_counts =
+                        [[8,  9, 12, 16, 25, 51],
+                         [11, 13, 14, 17, 25, 51],
+                         [ 9,  9, 11, 16, 23, 50]]

-            In this function, `ms_seg_counts` is used to get the actual length of each embedding sequence without
-            zero-padding.
+                In this function, `ms_seg_counts` is used to get the actual length of each embedding sequence without
+                zero-padding.

         Returns:
             ms_emb_seq (Tensor):
-                Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less repeated,
+                Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less repeated,
                 while shorter scales are more frequently repeated following the scale mapping tensor.
         """
         scale_n, batch_size = scale_mapping[0].shape[0], scale_mapping.shape[0]
@@ -409,9 +412,9 @@ def get_cluster_avg_embs_model(
                 [ 9,  9, 11, 16, 23, 50]
             ]
-        Counts of merged segments: (121, 131, 118)
-        embs has shape of (370, 192)
-        clus_label_index has shape of (3, 131)
+            Counts of merged segments: (121, 131, 118)
+            embs has shape of (370, 192)
+            clus_label_index has shape of (3, 131)

             Shape: (batch_size, scale_n)
@@ -553,7 +556,7 @@ def forward(
         with torch.no_grad():
             self.msdd._speaker_model.eval()
             logits, embs_d = self.msdd._speaker_model.forward_for_export(
-                processed_signal=audio_signal[detach_ids[1]], processed_signal_len=audio_signal_len[detach_ids[1]]
+                audio_signal=audio_signal[detach_ids[1]], length=audio_signal_len[detach_ids[1]]
             )
             embs = torch.zeros(audio_signal.shape[0], embs_d.shape[1]).to(embs_d.device)
             embs[detach_ids[1], :] = embs_d.detach()
@@ -854,9 +857,9 @@ def run_clustering_diarizer(self, manifest_filepath: str, emb_dir: str):
         os.makedirs(self.out_rttm_dir, exist_ok=True)

         self.clus_diar_model._cluster_params = self.cfg_diar_infer.diarizer.clustering.parameters
-        self.clus_diar_model.multiscale_args_dict[
-            "multiscale_weights"
-        ] = self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.multiscale_weights
+        self.clus_diar_model.multiscale_args_dict["multiscale_weights"] = (
+            self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.multiscale_weights
+        )
         self.clus_diar_model._diarizer_params.speaker_embeddings.parameters = (
             self.cfg_diar_infer.diarizer.speaker_embeddings.parameters
         )
@@ -1076,7 +1079,6 @@ def extract_standalone_speaker_model(self, prefix: str = 'msdd._speaker_model.')
         return _speaker_model

     def _init_msdd_model(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig]):
-
         """
         Initialized MSDD model with the provided config. Load either from `.nemo` file or `.ckpt` checkpoint files.
         """
@@ -1128,7 +1130,7 @@ def get_pred_mat(self, data_list: List[Union[Tuple[int], List[torch.Tensor]]]) -
         digit_map = dict(zip(sorted(set(all_tups)), range(n_est_spks)))
         total_len = max([sess[1].shape[1] for sess in data_list])
         sum_pred = torch.zeros(total_len, n_est_spks)
-        for (_dim_tup, pred_mat) in data_list:
+        for _dim_tup, pred_mat in data_list:
             dim_tup = [digit_map[x] for x in _dim_tup]
             if len(pred_mat.shape) == 3:
                 pred_mat = pred_mat.squeeze(0)
@@ -1167,8 +1169,7 @@ def get_integrated_preds_list(
         return output_list

     def get_emb_clus_infer(self, cluster_embeddings):
-        """Assign dictionaries containing the clustering results from the class instance `cluster_embeddings`.
-        """
+        """Assign dictionaries containing the clustering results from the class instance `cluster_embeddings`."""
         self.msdd_model.emb_sess_test_dict = cluster_embeddings.emb_sess_test_dict
         self.msdd_model.clus_test_label_dict = cluster_embeddings.clus_test_label_dict
         self.msdd_model.emb_seq_test = cluster_embeddings.emb_seq_test
@@ -1456,7 +1457,10 @@ def from_pretrained(
         """
         logging.setLevel(logging.INFO if verbose else logging.WARNING)
         cfg = NeuralDiarizerInferenceConfig.init_config(
-            diar_model_path=model_name, vad_model_path=vad_model_name, map_location=map_location, verbose=verbose,
+            diar_model_path=model_name,
+            vad_model_path=vad_model_name,
+            map_location=map_location,
+            verbose=verbose,
         )
         return cls(cfg)
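
The `ms_seg_counts` description in the docstrings above is easier to see with concrete tensors: the per-scale segment counts sum to a per-batch total, which lets the flattened, padding-free embedding stack be split back apart per batch item. A toy illustration using the docstring's own numbers (logic illustrative only, not the MSDD implementation):

    import torch

    ms_seg_counts = torch.tensor([[ 8,  9, 12, 16, 25, 51],
                                  [11, 13, 14, 17, 25, 51],
                                  [ 9,  9, 11, 16, 23, 50]])
    per_batch = ms_seg_counts.sum(dim=1)            # tensor([121, 131, 118]), the merged segment counts
    embs = torch.randn(int(per_batch.sum()), 192)   # shape (370, 192), as in the docstring
    chunks = torch.split(embs, per_batch.tolist())  # one (count_i, 192) block per batch item
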
diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py
index d723ce85d2ce..245404a7601c 100644
--- a/nemo/collections/asr/modules/conformer_encoder.py
+++ b/nemo/collections/asr/modules/conformer_encoder.py
@@ -501,6 +501,7 @@ def streaming_post_process(self, rets, keep_all_outputs=True):
     def forward(
         self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
     ):
+        self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device)
         return self.forward_internal(
             audio_signal,
             length,
@@ -512,8 +513,6 @@ def forward_internal(
         self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
     ):
-        self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device)
-
         if length is None:
             length = audio_signal.new_full(
                 (audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device
diff --git a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py
index dccc81b1816c..d70737b5135b 100644
--- a/nemo/collections/asr/parts/preprocessing/features.py
+++ b/nemo/collections/asr/parts/preprocessing/features.py
@@ -131,7 +131,7 @@ def clean_spectrogram_batch(spectrogram: torch.Tensor, spectrogram_len: torch.Te


 def splice_frames(x, frame_splicing):
-    """ Stacks frames together across feature dim
+    """Stacks frames together across feature dim

     input is batch_size, feature_dim, num_frames
     output is batch_size, feature_dim*frame_splicing, num_frames
@@ -261,7 +261,7 @@ def __init__(
         highfreq=None,
         log=True,
         log_zero_guard_type="add",
-        log_zero_guard_value=2 ** -24,
+        log_zero_guard_value=2**-24,
         dither=CONSTANT,
         pad_to=16,
         max_duration=16.7,
@@ -308,6 +308,7 @@ def __init__(
         self.hop_length = n_window_stride
         self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
         self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
+        self.exact_pad = exact_pad

         if exact_pad:
             logging.info("STFT using exact pad")
@@ -321,15 +322,6 @@ def __init__(
         window_fn = torch_windows.get(window, None)
         window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
         self.register_buffer("window", window_tensor)
-        self.stft = lambda x: torch.stft(
-            x,
-            n_fft=self.n_fft,
-            hop_length=self.hop_length,
-            win_length=self.win_length,
-            center=False if exact_pad else True,
-            window=self.window.to(dtype=torch.float),
-            return_complex=True,
-        )

         self.normalize = normalize
         self.log = log
@@ -388,6 +380,17 @@ def __init__(
         logging.debug(f"using grads: {use_grads}")
         logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}")

+    def stft(self, x):
+        return torch.stft(
+            x,
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            center=False if self.exact_pad else True,
+            window=self.window.to(dtype=torch.float),
+            return_complex=True,
+        )
+
     def log_zero_guard_value_fn(self, x):
         if isinstance(self.log_zero_guard_value, str):
             if self.log_zero_guard_value == "tiny":
@@ -508,7 +511,7 @@ def __init__(
         highfreq: Optional[float] = None,
         log: bool = True,
         log_zero_guard_type: str = "add",
-        log_zero_guard_value: Union[float, str] = 2 ** -24,
+        log_zero_guard_value: Union[float, str] = 2**-24,
         dither: float = 1e-5,
         window: str = "hann",
         pad_to: int = 0,
@@ -579,7 +582,7 @@ def __init__(

     @property
     def filter_banks(self):
-        """ Matches the analogous class """
+        """Matches the analogous class"""
         return self._mel_spec_extractor.mel_scale.fb

     def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float:
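
The features.py change above replaces a constructor lambda with a regular `stft` method. A lambda stored as a module attribute cannot be pickled and is opaque to `torch.export`/dynamo tracing, while a bound method is introspectable and overridable. The shape of the pattern, reduced to a self-contained sketch (names and defaults illustrative):

    import torch
    import torch.nn as nn

    class SpecFrontend(nn.Module):
        def __init__(self, n_fft=512, hop_length=160):
            super().__init__()
            self.n_fft = n_fft
            self.hop_length = hop_length
            # NOT: self.stft = lambda x: torch.stft(...)  (unpicklable, hostile to tracing)

        def stft(self, x):
            # bound method: picklable, visible to exporters, overridable in subclasses
            return torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_length, return_complex=True)

    spec = SpecFrontend().stft(torch.randn(1, 16000))  # complex spectrogram
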
diff --git a/nemo/collections/asr/parts/submodules/jasper.py b/nemo/collections/asr/parts/submodules/jasper.py
index e53f6299b08a..78f81ee555bc 100644
--- a/nemo/collections/asr/parts/submodules/jasper.py
+++ b/nemo/collections/asr/parts/submodules/jasper.py
@@ -478,7 +478,7 @@ def forward_for_export(self, x, lengths):
             mask = self.make_pad_mask(lengths, max_audio_length=max_len, device=x.device)
             mask = ~mask  # 0 represents value, 1 represents pad
             x = x.float()  # For stable AMP, SE must be computed at fp32.
-            x.masked_fill_(mask, 0.0)  # mask padded values explicitly to 0
+            x = x.masked_fill(mask, 0.0)  # mask padded values explicitly to 0
             y = self._se_pool_step(x, mask)  # [B, C, 1]
             y = y.transpose(1, -1)  # [B, 1, C]
             y = self.fc(y)  # [B, 1, C]
@@ -510,8 +510,8 @@ def _se_pool_step(self, x, mask):
         return y

     def set_max_len(self, max_len, seq_range=None):
-        """ Sets maximum input length.
-            Pre-calculates internal seq_range mask.
+        """Sets maximum input length.
+        Pre-calculates internal seq_range mask.
         """
         self.max_len = max_len
         if seq_range is None:
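
The jasper.py hunk above swaps in-place `masked_fill_` for out-of-place `masked_fill` on the export path. In-place mutation of tensors that alias graph inputs is a frequent source of dynamo/`torch.export` failures (compare the `aten.detach_.default` skip noted in the NLP tests later in this diff). The pattern in isolation:

    import torch

    x = torch.randn(2, 4, 8)
    mask = torch.zeros(2, 4, 8, dtype=torch.bool)
    mask[..., 4:] = True  # pretend positions 4+ are padding

    x = x.masked_fill(mask, 0.0)  # allocates a new tensor; trace-friendly
    # x.masked_fill_(mask, 0.0) would mutate in place and can break export tracing
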
diff --git a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py
index 0f8d3410398d..7d604c0b51bc 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py
@@ -122,7 +122,11 @@ def __getitem__(self, idx):


 def build_train_valid_test_datasets(
-    cfg, retro_config: RetroConfig, train_valid_test_num_samples, seq_length, tokenizer,
+    cfg,
+    retro_config: RetroConfig,
+    train_valid_test_num_samples,
+    seq_length,
+    tokenizer,
 ):

     # gpt dataset
@@ -135,7 +139,10 @@ def build_train_valid_test_datasets(
     }

     retro_train_ds, retro_valid_ds, retro_test_ds = get_retro_datasets(
-        config=retro_config, gpt_datasets=gpt_datasets, sample_length=seq_length, eod_token_id=tokenizer.eos_id,
+        config=retro_config,
+        gpt_datasets=gpt_datasets,
+        sample_length=seq_length,
+        eod_token_id=tokenizer.eos_id,
     )

     train_ds = (
diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py
index d4ea6bfcf094..f001e8f58d25 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults
+
 try:
     from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
     from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
diff --git a/nemo/collections/tts/modules/transformer.py b/nemo/collections/tts/modules/transformer.py
index 728b583919ff..25c177d221cc 100644
--- a/nemo/collections/tts/modules/transformer.py
+++ b/nemo/collections/tts/modules/transformer.py
@@ -102,7 +102,7 @@ def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1, pre_lnorm=Fals
         self.n_head = n_head
         self.d_model = d_model
         self.d_head = d_head
-        self.scale = 1 / (d_head ** 0.5)
+        self.scale = 1 / (d_head**0.5)
         self.pre_lnorm = pre_lnorm

         self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
@@ -125,13 +125,17 @@ def _forward(self, inp, attn_mask=None, conditioning=None):

         head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)

-        head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
-        head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
-        head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)
+        s0 = inp.size(0)
+        s1 = inp.size(1)
+        s2 = s0 * n_head

-        q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
-        k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
-        v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
+        head_q = head_q.view(s0, s1, n_head, d_head)
+        head_k = head_k.view(s0, s1, n_head, d_head)
+        head_v = head_v.view(s0, s1, n_head, d_head)
+
+        q = head_q.permute(2, 0, 1, 3).reshape(s2, s1, d_head)
+        k = head_k.permute(2, 0, 1, 3).reshape(s2, s1, d_head)
+        v = head_v.permute(2, 0, 1, 3).reshape(s2, s1, d_head)

         attn_score = torch.bmm(q, k.transpose(1, 2))
         attn_score.mul_(self.scale)
@@ -145,8 +149,8 @@ def _forward(self, inp, attn_mask=None, conditioning=None):
             attn_prob = self.dropatt(attn_prob)

         attn_vec = torch.bmm(attn_prob, v)
-        attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
-        attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), n_head * d_head)
+        attn_vec = attn_vec.view(n_head, s0, s1, d_head)
+        attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(s0, s1, n_head * d_head)

         # linear projection
         attn_out = self.o_net(attn_vec)
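
In the transformer.py hunks above, repeated `inp.size(...)` calls and `reshape(-1, ...)` are replaced by sizes computed once (`s0`, `s1`, `s2`). Spelling out every reshape dimension keeps the shapes friendly to dynamo's symbolic tracing instead of forcing it to infer a `-1` dimension. Condensed:

    import torch

    n_head, d_head = 2, 8
    inp = torch.randn(3, 5, n_head * d_head)  # (batch, time, n_head * d_head)

    s0, s1 = inp.size(0), inp.size(1)
    s2 = s0 * n_head

    head_q = inp.view(s0, s1, n_head, d_head)
    q = head_q.permute(2, 0, 1, 3).reshape(s2, s1, d_head)  # explicit s2, no -1
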
diff --git a/nemo/core/classes/common.py b/nemo/core/classes/common.py
index 97757b2e3826..60f842dbfb68 100644
--- a/nemo/core/classes/common.py
+++ b/nemo/core/classes/common.py
@@ -1015,8 +1015,14 @@ def __init__(

         self.ignore_collections = ignore_collections

+    def __call__(self, wrapped):
+        return self.wrapped_call(wrapped)
+
+    def unwrapped_call(self, wrapped):
+        return wrapped
+
     @wrapt.decorator(enabled=is_typecheck_enabled)
-    def __call__(self, wrapped, instance: Typing, args, kwargs):
+    def wrapped_call(self, wrapped, instance: Typing, args, kwargs):
         """
         Wrapper method that can be used on any function of a class that implements :class:`~nemo.core.Typing`.
         By default, it will utilize the `input_types` and `output_types` properties of the class inheriting Typing.
@@ -1125,3 +1131,11 @@ def disable_semantic_checks():
             yield
         finally:
             typecheck.set_semantic_check_enabled(enabled=True)
+
+    @staticmethod
+    def enable_wrapping(enabled: bool = True):
+        typecheck.set_typecheck_enabled(enabled)
+        if enabled:
+            typecheck.__call__ = nemo.core.classes.common.typecheck.wrapped_call
+        else:
+            typecheck.__call__ = nemo.core.classes.common.typecheck.unwrapped_call
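
The `typecheck` split above (`wrapped_call` vs `unwrapped_call`) exists so the `wrapt`-based decorator can be switched off entirely when it trips up dynamo (pytorch/pytorch#125462). Typical use, mirroring what `Exportable._export()` and the test files later in this diff do (a sketch, assuming a NeMo environment):

    from nemo.core.classes import typecheck

    typecheck.enable_wrapping(enabled=False)  # bypass the wrapt decorator
    try:
        pass  # run model.export(..., use_dynamo=True) / torch.onnx.dynamo_export here
    finally:
        typecheck.enable_wrapping(enabled=True)  # restore, as _export() does in its finally block
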
diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py
index 5bd1bb813ba3..aab09d42d907 100644
--- a/nemo/core/classes/exportable.py
+++ b/nemo/core/classes/exportable.py
@@ -20,12 +20,13 @@
 from nemo.core.classes import typecheck
 from nemo.core.neural_types import NeuralType
 from nemo.core.utils.neural_type_utils import get_dynamic_axes, get_io_names
-from nemo.utils import logging
+from nemo.utils import logging, monkeypatched
 from nemo.utils.export_utils import (
     ExportFormat,
     augment_filename,
     get_export_format,
     parse_input_example,
+    rename_onnx_io,
     replace_for_export,
     verify_runtime,
     verify_torchscript,
@@ -68,6 +69,7 @@ def export(
         check_tolerance=0.01,
         export_modules_as_functions=False,
         keep_initializers_as_inputs=None,
+        use_dynamo=False,
     ):
         """
         Exports the model to the specified format. The format is inferred from the file extension of the output file.
@@ -99,6 +101,7 @@ def export(
                 ONNX specific.
             keep_initializers_as_inputs (bool): If True, will keep the model's initializers as inputs in the onnx graph.
                 This is ONNX specific.
+            use_dynamo (bool): If True, use onnx.dynamo_export() instead of onnx.export(). This is ONNX specific.

         Returns:
             A tuple of two outputs.
@@ -122,6 +125,7 @@ def export(
                 check_tolerance=check_tolerance,
                 export_modules_as_functions=export_modules_as_functions,
                 keep_initializers_as_inputs=keep_initializers_as_inputs,
+                use_dynamo=use_dynamo,
             )
             # Propagate input example (default scenario, may need to be overriden)
             if input_example is not None:
@@ -143,6 +147,7 @@ def _export(
         check_tolerance=0.01,
         export_modules_as_functions=False,
         keep_initializers_as_inputs=None,
+        use_dynamo=False,
     ):
         my_args = locals().copy()
         my_args.pop('self')
@@ -162,7 +167,7 @@ def _export(

         # Pytorch's default opset version is too low, using reasonable latest one
         if onnx_opset_version is None:
-            onnx_opset_version = 16
+            onnx_opset_version = 17

         try:
             # Disable typechecks
@@ -189,14 +194,16 @@ def _export(
             input_list, input_dict = parse_input_example(input_example)
             input_names = self.input_names
             output_names = self.output_names
-            output_example = tuple(self.forward(*input_list, **input_dict))
+            output_example = self.forward(*input_list, **input_dict)
+            if not isinstance(output_example, tuple):
+                output_example = (output_example,)

             if check_trace:
                 if isinstance(check_trace, bool):
                     check_trace_input = [input_example]
                 else:
                     check_trace_input = check_trace

-            jitted_model = self
+
             if format == ExportFormat.TORCHSCRIPT:
                 jitted_model = torch.jit.trace_module(
                     self,
@@ -216,27 +223,64 @@ def _export(
             elif format == ExportFormat.ONNX:
                 # dynamic axis is a mapping from input/output_name => list of "dynamic" indices
                 if dynamic_axes is None:
-                    dynamic_axes = get_dynamic_axes(self.input_module.input_types_for_export, input_names)
-                    dynamic_axes.update(get_dynamic_axes(self.output_module.output_types_for_export, output_names))
-                torch.onnx.export(
-                    jitted_model,
-                    input_example,
-                    output,
-                    input_names=input_names,
-                    output_names=output_names,
-                    verbose=verbose,
-                    do_constant_folding=do_constant_folding,
-                    dynamic_axes=dynamic_axes,
-                    opset_version=onnx_opset_version,
-                    keep_initializers_as_inputs=keep_initializers_as_inputs,
-                    export_modules_as_functions=export_modules_as_functions,
-                )
+                    dynamic_axes = self.dynamic_shapes_for_export(use_dynamo)
+                if use_dynamo:
+                    typecheck.enable_wrapping(enabled=False)
+                    # https://github.com/pytorch/pytorch/issues/126339
+                    with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None):
+                        logging.info(f"Running export.export, dynamic shapes:{dynamic_axes}\n")
+
+                        # We have to use different types of arguments for dynamo_export to achieve
+                        # same external weights behaviour as onnx.export :
+                        # https://github.com/pytorch/pytorch/issues/126479
+                        # https://github.com/pytorch/pytorch/issues/126269
+                        mem_params = sum([param.nelement() * param.element_size() for param in self.parameters()])
+                        mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
+                        mem = mem_params + mem_bufs
+
+                        if mem > 2 * 1000 * 1000 * 1000:
+                            ex_model = torch.export.export(
+                                self,
+                                tuple(input_list),
+                                kwargs=input_dict,
+                                dynamic_shapes=dynamic_axes,
+                                strict=False,
+                            )
+                            ex_model = ex_model.run_decompositions()
+                            model_state = ex_model.state_dict
+                        else:
+                            model_state = None
+                            ex_model = self
+
+                        options = torch.onnx.ExportOptions(dynamic_shapes=True, op_level_debug=True)
+                        ex = torch.onnx.dynamo_export(ex_model, *input_list, **input_dict, export_options=options)
+                        ex.save(output, model_state=model_state)
+
+                        del ex
+                        del ex_model
+                        # Rename I/O after save - don't want to risk modifying ex._model_proto
+                        rename_onnx_io(output, input_names, output_names)
+                else:
+                    torch.onnx.export(
+                        self,
+                        input_example,
+                        output,
+                        input_names=input_names,
+                        output_names=output_names,
+                        verbose=verbose,
+                        do_constant_folding=do_constant_folding,
+                        dynamic_axes=dynamic_axes,
+                        opset_version=onnx_opset_version,
+                        keep_initializers_as_inputs=keep_initializers_as_inputs,
+                        export_modules_as_functions=export_modules_as_functions,
+                    )

                 if check_trace:
                     verify_runtime(self, output, check_trace_input, input_names, check_tolerance=check_tolerance)
             else:
                 raise ValueError(f'Encountered unknown export format {format}.')
         finally:
+            typecheck.enable_wrapping(enabled=True)
             typecheck.set_typecheck_enabled(enabled=True)
             if forward_method:
                 type(self).forward = old_forward_method
@@ -288,9 +332,12 @@ def input_types_for_export(self) -> Optional[Dict[str, NeuralType]]:
     def output_types_for_export(self):
         return self.output_types

+    def dynamic_shapes_for_export(self, use_dynamo=False):
+        return get_dynamic_axes(self.input_module.input_types_for_export, self.input_names, use_dynamo)
+
     def get_export_subnet(self, subnet=None):
         """
-        Returns Exportable subnet model/module to export
+        Returns Exportable subnet model/module to export
         """
         if subnet is None or subnet == 'self':
             return self
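
With the new `use_dynamo` flag, callers choose the exporter per call; the tests later in this diff run both paths back to back against the same file. A representative invocation (model name illustrative; any `Exportable` NeMo model follows the same shape):

    import os
    import tempfile

    from nemo.collections.asr.models import EncDecSpeakerLabelModel

    model = EncDecSpeakerLabelModel.from_pretrained(model_name="titanet_large").cuda()
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "model.onnx")
        model.export(output=path, check_trace=True, use_dynamo=False)  # classic TorchScript-based torch.onnx.export
        model.export(output=path, check_trace=True, use_dynamo=True)   # torch.onnx.dynamo_export path
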
diff --git a/nemo/core/utils/neural_type_utils.py b/nemo/core/utils/neural_type_utils.py
index 98ae442b9aa7..5a634dad3d57 100644
--- a/nemo/core/utils/neural_type_utils.py
+++ b/nemo/core/utils/neural_type_utils.py
@@ -14,7 +14,7 @@

 from collections import defaultdict
 from typing import Dict, List, Optional
-
+import torch
 from nemo.core.neural_types import AxisKind, NeuralType


@@ -30,19 +30,19 @@ def get_io_names(types: Optional[Dict[str, NeuralType]], disabled_names: List[st

 def extract_dynamic_axes(name: str, ntype: NeuralType):
     """
-    This method will extract BATCH and TIME dimension ids from each provided input/output name argument.
-
-    For example, if module/model accepts argument named "input_signal" with type corresponding to [Batch, Time, Dim]
-    shape, then the returned result should contain "input_signal" -> [0, 1] because Batch and Time are dynamic axes
-    as they can change from call to call during inference.
-
-    Args:
-        name: Name of input or output parameter
-        ntype: Corresponding Neural Type
-
-    Returns:
-
-    """
+    This method will extract BATCH and TIME dimension ids from each provided input/output name argument.
+
+    For example, if module/model accepts argument named "input_signal" with type corresponding to [Batch, Time, Dim]
+    shape, then the returned result should contain "input_signal" -> [0, 1] because Batch and Time are dynamic axes
+    as they can change from call to call during inference.
+
+    Args:
+        name: Name of input or output parameter
+        ntype: Corresponding Neural Type
+
+    Returns:
+
+    """

     def unpack_nested_neural_type(neural_type):
         if type(neural_type) in (list, tuple):
@@ -60,10 +60,23 @@ def unpack_nested_neural_type(neural_type):

     return dynamic_axes


-def get_dynamic_axes(types, names):
+def get_dynamic_axes(types, names, use_dynamo=False):
     dynamic_axes = defaultdict(list)
     if names is not None:
         for name in names:
             if name in types:
                 dynamic_axes.update(extract_dynamic_axes(name, types[name]))
+    if use_dynamo:
+        dynamic_shapes = {}
+        batch = torch.export.Dim("batch")
+        for name, dims in dynamic_axes.items():
+            ds = {}
+            for d in dims:
+                if d == 0:
+                    ds[d] = batch
+                # this currently has issues: https://github.com/pytorch/pytorch/issues/126127
+                else:
+                    ds[d] = torch.export.Dim(name + '__' + str(d))
+            dynamic_shapes[name] = ds
+        dynamic_axes = dynamic_shapes
     return dynamic_axes
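
`get_dynamic_axes` now emits two kinds of metadata: the classic `dynamic_axes` dict of axis-index lists for `torch.onnx.export`, or, when `use_dynamo=True`, a `dynamic_shapes` mapping built from `torch.export.Dim` objects, with one shared `batch` Dim so every input's batch dimension is constrained to agree. Using the docstring's own example of an input typed [Batch, Time, Dim]:

    import torch

    # classic exporter: name -> list of dynamic axis indices
    dynamic_axes = {"input_signal": [0, 1]}

    # dynamo path: name -> {axis index: Dim}; reusing the same "batch" Dim object
    # across inputs tells torch.export that the batch dims must be equal
    batch = torch.export.Dim("batch")
    dynamic_shapes = {"input_signal": {0: batch, 1: torch.export.Dim("input_signal__1")}}
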
diff --git a/nemo/utils/__init__.py b/nemo/utils/__init__.py
index ebf892927723..a1e59646ae13 100644
--- a/nemo/utils/__init__.py
+++ b/nemo/utils/__init__.py
@@ -21,6 +21,7 @@
     avoid_float16_autocast_context,
     cast_all,
     cast_tensor,
+    monkeypatched,
 )
 from nemo.utils.dtype import str_to_dtype
 from nemo.utils.nemo_logging import Logger as _Logger
diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py
index 21e977ec494d..a7960be4cc4d 100644
--- a/nemo/utils/cast_utils.py
+++ b/nemo/utils/cast_utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext

 import torch

@@ -91,3 +91,12 @@ def forward(self, *args):
             return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
         else:
             return self.mod.forward(*args)
+
+
+@contextmanager
+def monkeypatched(object, name, patch):
+    """Temporarily monkeypatches an object."""
+    pre_patched_value = getattr(object, name)
+    setattr(object, name, patch)
+    yield object
+    setattr(object, name, pre_patched_value)
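
The new `monkeypatched` helper gives `_export()` a scoped way to stub out `torch.nn.RNNBase.flatten_parameters` for the duration of a dynamo export (pytorch/pytorch#126339); the original attribute is put back when the block exits normally. Usage sketch:

    import torch

    from nemo.utils import monkeypatched

    with monkeypatched(torch.nn.RNNBase, "flatten_parameters", lambda *args: None):
        pass  # run torch.onnx.dynamo_export / model.export(use_dynamo=True) here
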
diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py
index 4c7a166437cc..c44530944051 100644
--- a/nemo/utils/export_utils.py
+++ b/nemo/utils/export_utils.py
@@ -126,6 +126,11 @@ def parse_input_example(input_example):

 def to_onnxrt_input(ort_input_names, input_names, input_dict, input_list):
     odict = {}
+    if not input_names:
+        input_list.extend(input_dict.values())
+        for k, v in zip(ort_input_names, input_list):
+            odict[k] = v.cpu().numpy()
+        return odict
     for k in reversed(input_names):
         val = None
         if k in input_dict:
@@ -172,6 +177,8 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0
     for input_example in input_examples:
         input_list, input_dict = parse_input_example(input_example)
         output_example = model.forward(*input_list, **input_dict)
+        if not isinstance(output_example, tuple):
+            output_example = (output_example,)
         ort_input = to_onnxrt_input(ort_input_names, input_names, input_dict, input_list)
         all_good = all_good and run_ort_and_compare(sess, ort_input, output_example, check_tolerance)
     status = "SUCCESS" if all_good else "FAIL"
@@ -216,10 +223,12 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
             try:
                 if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
                     this_good = False
-            except Exception:  # there may ne size mismatch and it may be OK
+            except Exception:  # there may be size mismatch and it may be OK
                 this_good = False
             if not this_good:
-                logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}")
+                logging.info(
+                    f"onnxruntime results mismatch! PyTorch(expected, {expected.shape}):\n{expected}\nONNXruntime, {tout.shape}:\n{tout}"
+                )
                 all_good = False
     return all_good

@@ -374,7 +383,7 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:

 def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
     """
-        Generic function generator to replace BaseT module with DestT wrapper.
+    Generic function generator to replace BaseT module with DestT wrapper.
     Args:
         BaseT : module type to replace
         DestT : destination module type
@@ -441,7 +450,7 @@ def script_module(m: nn.Module):

 def replace_for_export(model: nn.Module) -> nn.Module:
     """
-        Top-level function to replace 'default set' of modules in model, called from _prepare_for_export.
+    Top-level function to replace 'default set' of modules in model, called from _prepare_for_export.
     NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
     Args:
         model : top level module
@@ -474,3 +483,25 @@ def add_casts_around_norms(model: nn.Module):
         "MaskedInstanceNorm1d": wrap_module(MaskedInstanceNorm1d, CastToFloatAll),
     }
     replace_modules(model, default_cast_replacements)
+
+
+def rename_onnx_io(output, input_names, output_names):
+    onnx_model = onnx.load(output)
+    rename_map = {}
+    for inp, name in zip(onnx_model.graph.input, input_names):
+        rename_map[inp.name] = name
+    for out, name in zip(onnx_model.graph.output, output_names):
+        rename_map[out.name] = name
+    for n in onnx_model.graph.node:
+        for inp in range(len(n.input)):
+            if n.input[inp] in rename_map:
+                n.input[inp] = rename_map[n.input[inp]]
+        for out in range(len(n.output)):
+            if n.output[out] in rename_map:
+                n.output[out] = rename_map[n.output[out]]
+
+    for i in range(len(input_names)):
+        onnx_model.graph.input[i].name = input_names[i]
+    for i in range(len(output_names)):
+        onnx_model.graph.output[i].name = output_names[i]
+    onnx.save(onnx_model, output)
diff --git a/tests/collections/nlp/test_nlp_exportables.py b/tests/collections/nlp/test_nlp_exportables.py
index c0b97caea4ed..dbd5b3ac4427 100644
--- a/tests/collections/nlp/test_nlp_exportables.py
+++ b/tests/collections/nlp/test_nlp_exportables.py
@@ -21,6 +21,12 @@
 import wget
 from omegaconf import DictConfig, OmegaConf

+# WAR for https://github.com/pytorch/pytorch/issues/125462
+# Has to be applied before first import of NeMo
+from nemo.core.classes import typecheck
+
+typecheck.enable_wrapping(enabled=False)
+
 from nemo.collections import nlp as nemo_nlp
 from nemo.collections.nlp.models import IntentSlotClassificationModel
 from nemo.collections.nlp.modules.common import (
@@ -35,7 +41,7 @@ def classifier_export(obj):
     with tempfile.TemporaryDirectory() as tmpdir:
         filename = os.path.join(tmpdir, obj.__class__.__name__ + '.onnx')
         obj = obj.cuda()
-        obj.export(output=filename)
+        obj.export(output=filename, use_dynamo=True, check_trace=True)


 class TestExportableClassifiers:
@@ -175,7 +181,8 @@ def test_IntentSlotClassificationModel_export_to_onnx(self, dummy_data):
             trainer = pl.Trainer(**config.trainer)
             model = IntentSlotClassificationModel(config.model, trainer=trainer)
             filename = os.path.join(tmpdir, 'isc.onnx')
-            model.export(output=filename, check_trace=True)
+            model.export(output=filename, check_trace=True, use_dynamo=False)
+            model.export(output=filename, check_trace=True, use_dynamo=True)
             onnx_model = onnx.load(filename)
             onnx.checker.check_model(onnx_model, full_check=True)  # throws when failed
             assert onnx_model.graph.input[0].name == 'input_ids'
@@ -191,7 +198,8 @@ def test_TokenClassificationModel_export_to_onnx(self):
         model = nemo_nlp.models.TokenClassificationModel.from_pretrained(model_name="ner_en_bert")
         with tempfile.TemporaryDirectory() as tmpdir:
             filename = os.path.join(tmpdir, 'ner.onnx')
-            model.export(output=filename, check_trace=True)
+            model.export(output=filename, check_trace=True, use_dynamo=False)
+            model.export(output=filename, check_trace=True, use_dynamo=True)
             onnx_model = onnx.load(filename)
             onnx.checker.check_model(onnx_model, full_check=True)  # throws when failed
             assert onnx_model.graph.input[0].name == 'input_ids'
@@ -206,7 +214,9 @@ def test_PunctuationCapitalizationModel_export_to_onnx(self):
         model = nemo_nlp.models.PunctuationCapitalizationModel.from_pretrained(model_name="punctuation_en_distilbert")
         with tempfile.TemporaryDirectory() as tmpdir:
             filename = os.path.join(tmpdir, 'puncap.onnx')
-            model.export(output=filename, check_trace=True)
+            model.export(output=filename, check_trace=True, use_dynamo=False)
+            # Unsupported FX nodes: {'call_function': ['aten.detach_.default']}.
+            # model.export(output=filename, check_trace=True, use_dynamo=True)
             onnx_model = onnx.load(filename)
             onnx.checker.check_model(onnx_model, full_check=True)  # throws when failed
             assert onnx_model.graph.input[0].name == 'input_ids'
@@ -221,7 +231,8 @@ def test_QAModel_export_to_onnx(self):
         model = nemo_nlp.models.QAModel.from_pretrained(model_name="qa_squadv2.0_bertbase")
         with tempfile.TemporaryDirectory() as tmpdir:
             filename = os.path.join(tmpdir, 'qa.onnx')
-            model.export(output=filename, check_trace=True)
+            model.export(output=filename, check_trace=True, use_dynamo=False)
+            model.export(output=filename, check_trace=True, use_dynamo=True)
             onnx_model = onnx.load(filename)
             assert onnx_model.graph.input[0].name == 'input_ids'
             assert onnx_model.graph.input[1].name == 'attention_mask'
diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py
index 67f016b0c2af..68c9a55e1f8a 100644
--- a/tests/collections/tts/test_tts_exportables.py
+++ b/tests/collections/tts/test_tts_exportables.py
@@ -26,7 +26,7 @@ def fastpitch_model():
     model = FastPitchModel.from_pretrained(model_name="tts_en_fastpitch")
     model.export_config['enable_volume'] = True
-    model.export_config['enable_ragged_batches'] = True
+    # model.export_config['enable_ragged_batches'] = True
     return model

@@ -65,7 +65,7 @@ def test_FastPitchModel_export_to_onnx(self, fastpitch_model):
         model = fastpitch_model.cuda()
         with tempfile.TemporaryDirectory() as tmpdir:
             filename = os.path.join(tmpdir, 'fp.onnx')
-            model.export(output=filename, verbose=True, onnx_opset_version=14, check_trace=True)
+            model.export(output=filename, verbose=True, onnx_opset_version=14, check_trace=True, use_dynamo=True)

     @pytest.mark.with_downloads()
     @pytest.mark.run_only_on('GPU')
@@ -75,7 +75,7 @@ def test_HifiGanModel_export_to_onnx(self, hifigan_model):
         assert hifigan_model.generator is not None
         with tempfile.TemporaryDirectory() as tmpdir:
             filename = os.path.join(tmpdir, 'hfg.onnx')
-            model.export(output=filename, verbose=True, check_trace=True)
+            model.export(output=filename, use_dynamo=True, verbose=True, check_trace=True)

     @pytest.mark.pleasefixme
     @pytest.mark.run_only_on('GPU')
diff --git a/tutorials/multimodal/Multimodal Data Preparation.ipynb b/tutorials/multimodal/Multimodal Data Preparation.ipynb
index b3a38b8b5ec2..fb7bdee1402f 100644
--- a/tutorials/multimodal/Multimodal Data Preparation.ipynb
+++ b/tutorials/multimodal/Multimodal Data Preparation.ipynb
@@ -14,7 +14,8 @@
    ],
    "metadata": {
     "collapsed": false
-   }
+   },
+   "id": "88adf24c9f52084f"
  },
  {
   "cell_type": "code",
@@ -56,7 +57,8 @@
    ],
    "metadata": {
     "collapsed": false
-   }
+   },
+   "id": "bb0c8d61cdb92704"
  },
 {
  "attachments": {},
@@ -207,7 +209,8 @@
   },
   "source": [
    "Note: In this dummy dataset, you will likely see a success rate of 1.000 (no failures). However, for real datasets, the success rate will always be much less than 1.000"
-  ]
+  ],
+  "id": "eaffa123548d6a5e"
 },
 {
  "attachments": {},
@@ -649,7 +652,8 @@
    "\n",
    "After this, you can proceed with Stage 3 of the tutorial.\n",
    "Note: if you can use a script to create folders with exactly `tar_chunk_size` (1000 in the tutorial) image-text pairs, and create multiple tarfiles each with `tar_chunk_size` pairs of data, then you can skip Stage 3 and proceed with Stage 4 of the tutorial."
-  ]
+  ],
+  "id": "217dacb92b870798"
  }
 ],
 "metadata": {
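
Whichever exporter produced the file, the post-export validation stays the same; the NLP tests above rely on exactly this pattern (filename and expected input name illustrative):

    import onnx

    onnx_model = onnx.load("isc.onnx")
    onnx.checker.check_model(onnx_model, full_check=True)  # throws when failed
    assert onnx_model.graph.input[0].name == "input_ids"
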