From e208fc4b03b7af06fae3017c5131d3fdcf339bf1 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 17 Nov 2022 11:56:26 -0800 Subject: [PATCH 1/5] [TTS] Fixing RADTTS training - removing view buffer and fixing accuracy issue (#5358) * Fixing RADTTS training - removing view buffer and fixing accuracy issue Signed-off-by: Boris Fomitchev * Addressing code review Signed-off-by: Boris Fomitchev * Addressing code review 2 Signed-off-by: Boris Fomitchev * Fixed assignment Signed-off-by: Boris Fomitchev * Working script Signed-off-by: Boris Fomitchev * restored flatten_parameters Signed-off-by: Boris Fomitchev * Working bias alias for export Signed-off-by: Boris Fomitchev * Removing unused import Signed-off-by: Boris Fomitchev * Reverting PartialConv Signed-off-by: Boris Fomitchev * Removing flatten_parameters Signed-off-by: Boris Fomitchev * Moving mask updater to GPU Signed-off-by: Boris Fomitchev * Restored norms Signed-off-by: Boris Fomitchev * Restored flatten Signed-off-by: Boris Fomitchev * Moved to sort/unsort Signed-off-by: Boris Fomitchev * Moved to masked norm Signed-off-by: Boris Fomitchev * Turned off cache Signed-off-by: Boris Fomitchev * cleanup Signed-off-by: Boris Fomitchev * Verifying cache not used Signed-off-by: Boris Fomitchev * Removing cache Signed-off-by: Boris Fomitchev * Working autocast export Signed-off-by: Boris Fomitchev * restored e-6 Signed-off-by: Boris Fomitchev * Removed some casts around masks, etc Signed-off-by: Boris Fomitchev * Fixing some casts Signed-off-by: Boris Fomitchev * Fixing in-place ops Signed-off-by: Boris Fomitchev * fixing grad Signed-off-by: Boris Fomitchev * Small export fixes Signed-off-by: Boris Fomitchev * LGTM cleanup Signed-off-by: Boris Fomitchev * Fixed lstm_tensor Signed-off-by: Boris Fomitchev * restored TS check routine Signed-off-by: Boris Fomitchev * Fixed config error Signed-off-by: Boris Fomitchev * reverting some bad optimizations Signed-off-by: Boris Fomitchev * [TTS] add CI test for RADTTS training recipe. 
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> * Addressing code review Signed-off-by: Boris Fomitchev * Removing unused var Signed-off-by: Boris Fomitchev * Adding debug Signed-off-by: Boris Fomitchev * Logging fixes Signed-off-by: Boris Fomitchev * Fixing training warnings Signed-off-by: Boris Fomitchev * Fixing more warnings Signed-off-by: Boris Fomitchev * Fixing more warnings 2 Signed-off-by: Boris Fomitchev * Code review fixes Signed-off-by: Boris Fomitchev * Improving TS check Signed-off-by: Boris Fomitchev * Addressing code review comments, optimizing script Signed-off-by: Boris Fomitchev * Forced no-autocast Signed-off-by: Boris Fomitchev Signed-off-by: Boris Fomitchev Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com> Co-authored-by: Oleksii Kuchaiev Signed-off-by: Boris Fomitchev --- Jenkinsfile | 37 ++++- .../collections/common/callbacks/callbacks.py | 2 +- nemo/collections/tts/models/radtts.py | 19 +-- nemo/collections/tts/modules/common.py | 95 ++++------- nemo/collections/tts/modules/radtts.py | 90 ++++------ nemo/collections/tts/modules/submodules.py | 155 +++++++++--------- nemo/core/classes/exportable.py | 25 +-- nemo/core/optim/radam.py | 4 +- nemo/utils/cast_utils.py | 2 +- nemo/utils/export_utils.py | 20 ++- scripts/export.py | 3 +- tests/collections/tts/test_tts_exportables.py | 36 +++- 12 files changed, 238 insertions(+), 250 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 20d052127ae1..b8966e8e5bbd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -4111,7 +4111,9 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' validation_datasets=/home/TestData/an4_dataset/an4_val.json \ sup_data_path=/home/TestData/an4_dataset/beta_priors \ trainer.devices="[0]" \ - +trainer.limit_train_batches=1 +trainer.limit_val_batches=1 trainer.max_epochs=1 \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + trainer.max_epochs=1 \ trainer.strategy=null \ model.train_ds.dataloader_params.batch_size=4 \ model.train_ds.dataloader_params.num_workers=0 \ @@ -4127,6 +4129,31 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' ~model.text_normalizer_call_kwargs' } } + stage('RADTTS') { + steps { + sh 'python examples/tts/radtts.py \ + train_dataset=/home/TestData/an4_dataset/an4_train.json \ + validation_datasets=/home/TestData/an4_dataset/an4_val.json \ + sup_data_path=/home/TestData/an4_dataset/radtts_beta_priors \ + trainer.devices="[0]" \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + trainer.max_epochs=1 \ + trainer.strategy=null \ + model.pitch_mean=212.35873413085938 \ + model.pitch_std=68.52806091308594 \ + model.train_ds.dataloader_params.batch_size=4 \ + model.train_ds.dataloader_params.num_workers=0 \ + model.validation_ds.dataloader_params.batch_size=4 \ + model.validation_ds.dataloader_params.num_workers=0 \ + export_dir=/home/TestData/radtts_test \ + model.optim.lr=0.0001 \ + model.modelConfig.decoder_use_partial_padding=True \ + ~trainer.check_val_every_n_epoch \ + ~model.text_normalizer \ + ~model.text_normalizer_call_kwargs' + } + } stage('Mixer-TTS') { steps { sh 'python examples/tts/mixer_tts.py \ @@ -4134,7 +4161,9 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' validation_datasets=/home/TestData/an4_dataset/an4_val.json \ sup_data_path=/home/TestData/an4_dataset/sup_data \ trainer.devices="[0]" \ - +trainer.limit_train_batches=1 
+trainer.limit_val_batches=1 trainer.max_epochs=1 \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + trainer.max_epochs=1 \ trainer.strategy=null \ model.train_ds.dataloader_params.batch_size=4 \ model.train_ds.dataloader_params.num_workers=0 \ @@ -4151,7 +4180,9 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' train_dataset=/home/TestData/an4_dataset/an4_train.json \ validation_datasets=/home/TestData/an4_dataset/an4_val.json \ trainer.devices="[0]" \ - +trainer.limit_train_batches=1 +trainer.limit_val_batches=1 +trainer.max_epochs=1 \ + +trainer.limit_train_batches=1 \ + +trainer.limit_val_batches=1 \ + +trainer.max_epochs=1 \ trainer.strategy=null \ model.train_ds.dataloader_params.batch_size=4 \ model.train_ds.dataloader_params.num_workers=0 \ diff --git a/nemo/collections/common/callbacks/callbacks.py b/nemo/collections/common/callbacks/callbacks.py index 489c862b3780..1a6c011c38df 100644 --- a/nemo/collections/common/callbacks/callbacks.py +++ b/nemo/collections/common/callbacks/callbacks.py @@ -13,7 +13,7 @@ # limitations under the License. import time -from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities import rank_zero_only # from sacrebleu import corpus_bleu diff --git a/nemo/collections/tts/models/radtts.py b/nemo/collections/tts/models/radtts.py index 47251b4a3f61..30b6189c484f 100644 --- a/nemo/collections/tts/models/radtts.py +++ b/nemo/collections/tts/models/radtts.py @@ -27,7 +27,6 @@ from nemo.collections.tts.helpers.helpers import plot_alignment_to_numpy from nemo.collections.tts.losses.radttsloss import AttentionBinarizationLoss, RADTTSLoss from nemo.collections.tts.models.base import SpectrogramGenerator -from nemo.collections.tts.modules.submodules import PartialConv1d from nemo.core.classes import Exportable from nemo.core.classes.common import typecheck from nemo.core.neural_types.elements import Index, MelSpectrogramType, TokenIndex @@ -159,7 +158,7 @@ def training_step(self, batch, batch_idx): loss_outputs['binarization_loss'] = (binarization_loss, 1.0) for k, (v, w) in loss_outputs.items(): - self.log("train/" + k, loss_outputs[k][0]) + self.log("train/" + k, loss_outputs[k][0], on_step=True) return {'loss': loss} @@ -229,7 +228,7 @@ def validation_epoch_end(self, outputs): for k, v in loss_outputs.items(): if k != "binarization_loss": - self.log("val/" + k, loss_outputs[k][0]) + self.log("val/" + k, loss_outputs[k][0], sync_dist=True, on_epoch=True) attn = outputs[0]["attn"] attn_soft = outputs[0]["attn_soft"] @@ -407,17 +406,3 @@ def output_module(self): def forward_for_export(self, text, lens, speaker_id, speaker_id_text, speaker_id_attributes): return self.model.forward_for_export(text, lens, speaker_id, speaker_id_text, speaker_id_attributes) - - def get_export_subnet(self, subnet=None): - return self.model.get_export_subnet(subnet) - - def _prepare_for_export(self, **kwargs): - """ - Override this method to prepare module for export. This is in-place operation. 
- Base version does common necessary module replacements (Apex etc) - """ - PartialConv1d.forward = PartialConv1d.forward_no_cache - super()._prepare_for_export(**kwargs) - - def _export_teardown(self): - PartialConv1d.forward = PartialConv1d.forward_with_cache diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 7eff0c4c3baf..5d93340b4c2d 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -30,7 +30,7 @@ piecewise_linear_transform, unbounded_piecewise_quadratic_transform, ) -from nemo.collections.tts.modules.submodules import ConvNorm, LinearNorm +from nemo.collections.tts.modules.submodules import ConvNorm, LinearNorm, MaskedInstanceNorm1d @torch.jit.script @@ -45,7 +45,7 @@ def get_mask_from_lengths_and_val(lengths, val): max_len = val.shape[-1] ids = torch.arange(0, max_len, device=lengths.device) mask = ids < lengths.unsqueeze(1) - return mask.float() + return mask @torch.jit.script @@ -124,30 +124,31 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals seq = nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted ) - if not torch.jit.is_scripting(): + if not (torch.jit.is_scripting() or torch.jit.is_tracing()): self.bilstm.flatten_parameters() - ret, _ = self.bilstm(seq) + if hasattr(self.bilstm, 'forward'): + ret, _ = self.bilstm.forward(seq) + else: + ret, _ = self.bilstm.forward_1(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]: - if not torch.jit.is_scripting(): + if not (torch.jit.is_scripting() or torch.jit.is_tracing()): self.bilstm.flatten_parameters() - ret, _ = self.bilstm(seq) + if hasattr(self.bilstm, 'forward'): + ret, _ = self.bilstm.forward(seq) + elif hasattr(self.bilstm, 'forward_1'): + ret, _ = self.bilstm.forward_1(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def sort_and_lstm_tensor(self, context: Tensor, lens: Tensor) -> Tensor: - lens_sorted, ids_sorted = torch.sort(lens, descending=True) - unsort_ids = torch.zeros_like(ids_sorted) - for i in range(ids_sorted.shape[0]): - unsort_ids[ids_sorted[i]] = i - context = context[ids_sorted] + context, lens_sorted, unsort_ids = sort_tensor(context, lens) seq = nn.utils.rnn.pack_padded_sequence( context, lens_sorted.long().cpu(), batch_first=True, enforce_sorted=True ) - ret, _ = self.bilstm(seq) - return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)[0][unsort_ids] + return self.lstm_sequence(seq)[0][unsort_ids] class ConvLSTMLinear(BiLSTM): @@ -161,14 +162,14 @@ def __init__( p_dropout=0.1, use_partial_padding=False, norm_fn=None, - lstm_norm_fn="spectral", ): super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1) - self.out_dim = out_dim + self.convolutions = nn.ModuleList() if n_layers > 0: self.dropout = nn.Dropout(p=p_dropout) - self.convolutions = nn.ModuleList() + + use_weight_norm = norm_fn is None for i in range(n_layers): conv_layer = ConvNorm( @@ -179,14 +180,13 @@ def __init__( padding=int((kernel_size - 1) / 2), dilation=1, w_init_gain='relu', - use_weight_norm=False, + use_weight_norm=use_weight_norm, use_partial_padding=use_partial_padding, norm_fn=norm_fn, ) if norm_fn is not None: print("Applying {} norm to {}".format(norm_fn, conv_layer)) else: - conv_layer = torch.nn.utils.weight_norm(conv_layer.conv) print("Applying weight norm to 
{}".format(conv_layer)) self.convolutions.append(conv_layer) @@ -194,57 +194,23 @@ def __init__( if out_dim is not None: self.dense = nn.Linear(n_channels, out_dim) - @torch.jit.export - def conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence: - context_embedded = [] - bs: int = context.shape[0] - b_ind: int = 0 - for b_ind in range(bs): # TODO: speed up - curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone() - for conv in self.convolutions: - curr_context = self.dropout(F.relu(conv(curr_context))) - context_embedded.append(curr_context[0].transpose(0, 1)) - seq = torch.nn.utils.rnn.pack_sequence(context_embedded, enforce_sorted=enforce_sorted) - return seq - - @torch.jit.export - def conv_to_padded_tensor(self, context: Tensor, lens: Tensor) -> Tensor: - context_embedded = [] - bs: int = context.shape[0] - b_ind: int = 0 - for b_ind in range(bs): # TODO: speed up - curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone() - for conv in self.convolutions: - curr_context = self.dropout(F.relu(conv(curr_context))) - context_embedded.append(curr_context[0].transpose(0, 1)) - ret = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True) - return ret - - @torch.jit.export def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence: mask = get_mask_from_lengths_and_val(lens, context) - mask = mask.unsqueeze(1) + mask = mask.to(dtype=context.dtype).unsqueeze(1) for conv in self.convolutions: context = self.dropout(F.relu(conv(context, mask))) - context = torch.mul(context, mask) + context = context.transpose(1, 2) seq = torch.nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted ) return seq - def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor: - if lens is None: - for conv in self.convolutions: - context = self.dropout(F.relu(conv(context))) - context = context.transpose(1, 2) - context, _ = self.bilstm(context) - else: - # borisf : does not match ADLR (values, lengths) - # seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=False) - # borisf : does match ADLR - seq = self.conv_to_sequence(context, lens, enforce_sorted=False) - context, _ = self.lstm_sequence(seq) + def forward(self, context: Tensor, lens: Tensor) -> Tensor: + context, lens, unsort_ids = sort_tensor(context, lens) + seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=True) + context, _ = self.lstm_sequence(seq) + context = context[unsort_ids] if self.dense is not None: context = self.dense(context).permute(0, 2, 1) @@ -252,12 +218,8 @@ def forward(self, context: Tensor, lens: Optional[Tensor] = None) -> Tensor: return context -def getRadTTSEncoder( - encoder_n_convolutions=3, - encoder_embedding_dim=512, - encoder_kernel_size=5, - norm_fn=nn.BatchNorm1d, - lstm_norm_fn=None, +def get_radtts_encoder( + encoder_n_convolutions=3, encoder_embedding_dim=512, encoder_kernel_size=5, norm_fn=MaskedInstanceNorm1d, ): return ConvLSTMLinear( in_dim=encoder_embedding_dim, @@ -267,7 +229,6 @@ def getRadTTSEncoder( p_dropout=0.5, use_partial_padding=True, norm_fn=norm_fn, - lstm_norm_fn=lstm_norm_fn, ) @@ -275,7 +236,7 @@ class Invertible1x1ConvLUS(torch.nn.Module): def __init__(self, c): super(Invertible1x1ConvLUS, self).__init__() # Sample a random orthonormal matrix to initialize weights - W = torch.qr(torch.FloatTensor(c, c).normal_())[0] + W, _ = torch.linalg.qr(torch.FloatTensor(c, 
c).normal_()) # Ensure determinant is 1.0 not -1.0 if torch.det(W) < 0: W[:, 0] = -1 * W[:, 0] diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index 83bbcda58230..d41e7dd628e5 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -31,10 +31,8 @@ Invertible1x1ConvLUS, LinearNorm, get_mask_from_lengths, - getRadTTSEncoder, - sort_tensor, + get_radtts_encoder, ) -from nemo.collections.tts.modules.submodules import PartialConv1d from nemo.core.classes import Exportable, NeuralModule from nemo.core.neural_types.elements import Index, LengthsType, MelSpectrogramType, TokenDurationType, TokenIndex from nemo.core.neural_types.neural_type import NeuralType @@ -62,11 +60,11 @@ def pad_energy_avg_and_f0(energy_avg, f0, max_out_len): def adjust_f0(f0, f0_mean, f0_std, vmask_bool): if f0_mean > 0.0: - f0_mu, f0_sigma = f0[vmask_bool].mean(), f0[vmask_bool].std() - f0[vmask_bool] = (f0[vmask_bool] - f0_mu) / f0_sigma + f0_sigma, f0_mu = torch.std_mean(f0[vmask_bool]) + f0 = ((f0 - f0_mu) / f0_sigma).to(dtype=f0.dtype) f0_std = f0_std if f0_std > 0 else f0_sigma - f0[vmask_bool] = f0[vmask_bool] * f0_std + f0_mean - return f0 + f0 = (f0 * f0_std + f0_mean).to(dtype=f0.dtype) + return f0.masked_fill(~vmask_bool, 0.0) class FlowStep(nn.Module): @@ -146,8 +144,6 @@ def __init__( n_flows, n_conv_layers_per_step, n_mel_channels, - n_hidden, - mel_encoder_n_hidden, dummy_speaker_embedding, n_early_size, n_early_every, @@ -185,9 +181,7 @@ def __init__( self.speaker_embedding = torch.nn.Embedding(n_speakers, self.n_speaker_dim) self.embedding = torch.nn.Embedding(n_text, n_text_dim) self.flows = torch.nn.ModuleList() - self.encoder = getRadTTSEncoder( - encoder_embedding_dim=n_text_dim, norm_fn=nn.InstanceNorm1d, lstm_norm_fn=text_encoder_lstm_norm - ) + self.encoder = get_radtts_encoder(encoder_embedding_dim=n_text_dim) self.dummy_speaker_embedding = dummy_speaker_embedding self.learn_alignments = learn_alignments self.affine_activation = affine_activation @@ -196,11 +190,11 @@ def __init__( self.use_context_lstm = bool(use_context_lstm) self.context_lstm_norm = context_lstm_norm self.context_lstm_w_f0_and_energy = context_lstm_w_f0_and_energy - # self.length_regulator = LengthRegulator() self.use_first_order_features = bool(use_first_order_features) self.decoder_use_unvoiced_bias = kwargs['decoder_use_unvoiced_bias'] self.ap_pred_log_f0 = ap_pred_log_f0 self.ap_use_unvoiced_bias = kwargs['ap_use_unvoiced_bias'] + if 'atn' in include_modules or 'dec' in include_modules: if self.learn_alignments: self.attention = ConvAttention(n_mel_channels, self.n_speaker_dim, n_text_dim) @@ -218,12 +212,6 @@ def __init__( n_in_context_lstm = n_f0_dims + n_energy_avg_dims + n_text_dim n_in_context_lstm *= n_group_size n_in_context_lstm += self.n_speaker_dim - - n_context_hidden = n_f0_dims + n_energy_avg_dims + n_text_dim - n_context_hidden = n_context_hidden * n_group_size / 2 - n_context_hidden = self.n_speaker_dim + n_context_hidden - n_context_hidden = int(n_context_hidden) - n_flowstep_cond_dims = self.n_speaker_dim + n_text_dim * n_group_size self.context_lstm = BiLSTM( @@ -358,7 +346,7 @@ def preprocess_context(self, context, speaker_vecs, out_lens, f0, energy_avg): context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1) unfolded_out_lens = out_lens // self.n_group_size - context_lstm_padded_output, _ = self.context_lstm.lstm_tensor( + context_lstm_padded_output = self.context_lstm.sort_and_lstm_tensor( 
context_w_spkvec.transpose(1, 2), unfolded_out_lens ) context_w_spkvec = context_lstm_padded_output.transpose(1, 2) @@ -466,10 +454,11 @@ def forward( f0_bias = 0 # unvoiced bias forward pass + voiced_mask_bool = voiced_mask.bool() if self.use_unvoiced_bias: f0_bias = self.unvoiced_bias_module(context.permute(0, 2, 1)) f0_bias = -f0_bias[..., 0] - f0_bias = f0_bias * (~voiced_mask.bool()).float() + f0_bias.masked_fill_(voiced_mask_bool, 0.0) # mel decoder forward pass if 'dec' in self.include_modules: @@ -478,7 +467,6 @@ def forward( # sometimes referred to as the "squeeze" operation # invert this by calling self.fold(mel_or_z) mel = self.unfold(mel.unsqueeze(-1)) - z_out = [] # where context is folded # mask f0 in case values are interpolated context_w_spkvec = self.preprocess_context( @@ -542,7 +530,7 @@ def forward( else: f0_target = torch.detach(f0) # fit to log f0 in f0 predictor - f0_target[voiced_mask.bool()] = torch.log(f0_target[voiced_mask.bool()]) + f0_target[voiced_mask_bool] = torch.log(f0_target[voiced_mask_bool]) f0_target = f0_target / 6 # scale to ~ [0, 1] in log space energy_avg = energy_avg * 2 - 1 # scale to ~ [-1, 1] @@ -603,8 +591,6 @@ def infer( voiced_mask=None, ): - # print ("Text, lens: ", text.shape, in_lens.shape) - batch_size = text.shape[0] n_tokens = text.shape[1] spk_vec = self.encode_speaker(speaker_id) @@ -615,7 +601,6 @@ def infer( spk_vec_text = self.encode_speaker(speaker_id_text) spk_vec_attributes = self.encode_speaker(speaker_id_attributes) txt_enc, _ = self.encode_text(text, in_lens) - print("txt_enc: ", txt_enc.shape) if dur is None: # get token durations @@ -626,9 +611,7 @@ def infer( dur = dur[:, 0] dur = dur.clamp(0, token_duration_max) - # get attributes f0, energy, vpred, etc) txt_enc_time_expanded, out_lens = regulate_len(dur, txt_enc.transpose(1, 2), pace) - # print ("txt_enc_time_expanded, out_lens, dur: ", txt_enc_time_expanded.shape, out_lens, dur) n_groups = torch.div(out_lens, self.n_group_size, rounding_mode='floor') max_out_len = torch.max(out_lens) @@ -637,8 +620,10 @@ def infer( if self.use_vpred_module: # get logits voiced_mask = self.v_pred_module.infer(None, txt_enc_time_expanded, spk_vec_attributes, lens=out_lens) - voiced_mask = torch.sigmoid(voiced_mask[:, 0]) > 0.5 - voiced_mask = voiced_mask.float() + voiced_mask_bool = torch.sigmoid(voiced_mask[:, 0]) > 0.5 + voiced_mask = voiced_mask_bool.to(dur.dtype) + else: + voiced_mask_bool = voiced_mask.bool() ap_txt_enc_time_expanded = txt_enc_time_expanded # voice mask augmentation only used for attribute prediction @@ -650,14 +635,13 @@ def infer( if self.use_unvoiced_bias: f0_bias = self.unvoiced_bias_module(txt_enc_time_expanded.permute(0, 2, 1)) f0_bias = -f0_bias[..., 0] - f0_bias = f0_bias * (~voiced_mask.bool()).float() if f0 is None: n_f0_feature_channels = 2 if self.use_first_order_features else 1 z_f0 = torch.normal(txt_enc.new_zeros(batch_size, n_f0_feature_channels, max_out_len)) * sigma_f0 - f0 = self.infer_f0(z_f0, ap_txt_enc_time_expanded, spk_vec_attributes, voiced_mask, out_lens)[:, 0] + f0 = self.infer_f0(z_f0, ap_txt_enc_time_expanded, spk_vec_attributes, voiced_mask_bool, out_lens)[:, 0] - f0 = adjust_f0(f0, f0_mean, f0_std, voiced_mask.to(dtype=bool)) + f0 = adjust_f0(f0, f0_mean, f0_std, voiced_mask_bool) if energy_avg is None: n_energy_feature_channels = 2 if self.use_first_order_features else 1 @@ -669,20 +653,17 @@ def infer( # replication pad, because ungrouping with different group sizes # may lead to mismatched lengths # FIXME: use replication pad - 
print("V mask, energy_avg, f0, f0_bias: ", voiced_mask.shape, energy_avg.shape, f0.shape, f0_bias.shape) (energy_avg, f0) = pad_energy_avg_and_f0(energy_avg, f0, max_out_len) - print("V mask, energy_avg, f0, f0_bias: ", voiced_mask.shape, energy_avg.shape, f0.shape, f0_bias.shape) context_w_spkvec = self.preprocess_context( - txt_enc_time_expanded, spk_vec, out_lens, f0 * voiced_mask + f0_bias, energy_avg + txt_enc_time_expanded, spk_vec, out_lens, (f0 + f0_bias) * voiced_mask, energy_avg ) residual = torch.normal(txt_enc.new_zeros(batch_size, 80 * self.n_group_size, torch.max(n_groups))) * sigma # map from z sample to data num_steps_to_exit = len(self.exit_steps) - mel = residual[:, num_steps_to_exit * self.n_early_size :] - remaining_residual = residual[:, : num_steps_to_exit * self.n_early_size] + remaining_residual, mel = torch.tensor_split(residual, [num_steps_to_exit * self.n_early_size,], dim=1) for i, flow_step in enumerate(reversed(self.flows)): curr_step = self.n_flows - i - 1 @@ -690,22 +671,19 @@ def infer( if num_steps_to_exit > 0 and curr_step == self.exit_steps[num_steps_to_exit - 1]: # concatenate the next chunk of z num_steps_to_exit = num_steps_to_exit - 1 - residual_to_add = remaining_residual[:, num_steps_to_exit * self.n_early_size :] - remaining_residual = remaining_residual[:, : num_steps_to_exit * self.n_early_size] + remaining_residual, residual_to_add = torch.tensor_split( + remaining_residual, [num_steps_to_exit * self.n_early_size,], dim=1 + ) mel = torch.cat((residual_to_add, mel), 1) if self.n_group_size > 1: mel = self.fold(mel) - # print ("mel=", mel.shape, "out_lens=", out_lens, "dur=", dur.shape) - return {'mel': mel, 'out_lens': out_lens, 'dur': dur, 'f0': f0, 'energy_avg': energy_avg} def infer_f0(self, residual, txt_enc_time_expanded, spk_vec, voiced_mask=None, lens=None): f0 = self.f0_pred_module.infer(residual, txt_enc_time_expanded, spk_vec, lens) - if voiced_mask is not None and len(voiced_mask.shape) == 2: - voiced_mask = voiced_mask[:, None] # constants if self.ap_pred_log_f0: if self.use_first_order_features: @@ -720,14 +698,15 @@ def infer_f0(self, residual, txt_enc_time_expanded, spk_vec, voiced_mask=None, l if voiced_mask is None: voiced_mask = f0 > 0.0 else: - voiced_mask = voiced_mask.bool() - # due to grouping, f0 might be 1 frame short - voiced_mask = voiced_mask[:, :, : f0.shape[-1]] + if len(voiced_mask.shape) == 2: + voiced_mask = voiced_mask[:, None] + # due to grouping, f0 might be 1 frame short + voiced_mask = voiced_mask[:, :, : f0.shape[-1]] + if self.ap_pred_log_f0: # if variable is set, decoder sees linear f0 - # mask = f0 > 0.0 if voiced_mask is None else voiced_mask.bool() - f0[voiced_mask] = torch.exp(f0[voiced_mask]).to(f0) - f0[~voiced_mask] = 0.0 + f0 = torch.exp(f0).to(dtype=f0.dtype) + f0.masked_fill_(~voiced_mask, 0.0) return f0 def infer_energy(self, residual, txt_enc_time_expanded, spk_vec, lens): @@ -783,17 +762,8 @@ def output_types(self): # Methods for model exportability def _prepare_for_export(self, **kwargs): - PartialConv1d.forward = PartialConv1d.forward_no_cache self.remove_norms() super()._prepare_for_export(**kwargs) - self.encoder = torch.jit.script(self.encoder) - self.v_pred_module.feat_pred_fn = torch.jit.script(self.v_pred_module.feat_pred_fn) - self.f0_pred_module.feat_pred_fn = torch.jit.script(self.f0_pred_module.feat_pred_fn) - self.energy_pred_module.feat_pred_fn = torch.jit.script(self.energy_pred_module.feat_pred_fn) - self.dur_pred_layer.feat_pred_fn = 
torch.jit.script(self.dur_pred_layer.feat_pred_fn)
-
-        if self.use_context_lstm:
-            self.context_lstm = torch.jit.script(self.context_lstm)
 
     def input_example(self, max_batch=1, max_dim=256):
         """
@@ -804,7 +774,7 @@ def input_example(self, max_batch=1, max_dim=256):
         par = next(self.parameters())
         sz = (max_batch, max_dim)
         inp = torch.randint(0, 16, sz, device=par.device, dtype=torch.int64)
-        lens = torch.randint(0, max_dim, (max_batch,), device=par.device, dtype=torch.int)
+        lens = torch.randint(16, max_dim, (max_batch,), device=par.device, dtype=torch.int)
         speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64)
         inputs = {
             'text': inp,
diff --git a/nemo/collections/tts/modules/submodules.py b/nemo/collections/tts/modules/submodules.py
index 90dd822e1650..e61b9b224885 100644
--- a/nemo/collections/tts/modules/submodules.py
+++ b/nemo/collections/tts/modules/submodules.py
@@ -12,13 +12,56 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional, Tuple
+from typing import Tuple
 
 import torch
+from torch import Tensor
 from torch.autograd import Variable
 from torch.nn import functional as F
 
 
+def masked_instance_norm(
+    input: Tensor, mask: Tensor, weight: Tensor, bias: Tensor, momentum: float, eps: float = 1e-5,
+) -> Tensor:
+    r"""Applies Masked Instance Normalization for each channel in each data sample in a batch.
+
+    See :class:`~MaskedInstanceNorm1d` for details.
+    """
+    lengths = mask.sum((-1,))
+    mean = (input * mask).sum((-1,)) / lengths  # (N, C)
+    var = (((input - mean[(..., None)]) * mask) ** 2).sum((-1,)) / lengths  # (N, C)
+    out = (input - mean[(..., None)]) / torch.sqrt(var[(..., None)] + eps)  # (N, C, ...)
+    out = out * weight[None, :][(..., None)] + bias[None, :][(..., None)]
+
+    return out
+
+
+class MaskedInstanceNorm1d(torch.nn.InstanceNorm1d):
+    r"""Applies Instance Normalization over a masked 3D input
+    (a mini-batch of 1D inputs with additional channel dimension).
+
+    See documentation of :class:`~torch.nn.InstanceNorm1d` for details.
+
+    Shape:
+        - Input: :math:`(N, C, L)`
+        - Mask: :math:`(N, 1, L)`
+        - Output: :math:`(N, C, L)` (same shape as input)
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-5,
+        momentum: float = 0.1,
+        affine: bool = False,
+        track_running_stats: bool = False,
+    ) -> None:
+        super(MaskedInstanceNorm1d, self).__init__(num_features, eps, momentum, affine, track_running_stats)
+
+    def forward(self, input: Tensor, mask: Tensor) -> Tensor:
+        return masked_instance_norm(input, mask, self.weight, self.bias, self.momentum, self.eps,)
+
+
 class PartialConv1d(torch.nn.Conv1d):
     """
     Zero padding creates a unique identifier for where the edge of the data is, such that the model can almost always identify
     this effect.
""" + __constants__ = ['slide_winsize'] + slide_winsize: float + def __init__(self, *args, **kwargs): super(PartialConv1d, self).__init__(*args, **kwargs) weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0]) self.register_buffer("weight_maskUpdater", weight_maskUpdater, persistent=False) - slide_winsize = torch.tensor(self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2]) - self.register_buffer("slide_winsize", slide_winsize, persistent=False) + self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] - if self.bias is not None: - bias_view = self.bias.view(1, self.out_channels, 1) - self.register_buffer('bias_view', bias_view, persistent=False) - # caching part - self.last_size = (-1, -1, -1) - - update_mask = torch.ones(1, 1, 1) - self.register_buffer('update_mask', update_mask, persistent=False) - mask_ratio = torch.ones(1, 1, 1) - self.register_buffer('mask_ratio', mask_ratio, persistent=False) - self.partial: bool = True - - def calculate_mask(self, input: torch.Tensor, mask_in: Optional[torch.Tensor]): + def forward(self, input, mask_in): + if mask_in is None: + mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device) + else: + mask = mask_in + input = torch.mul(input, mask) with torch.no_grad(): - if mask_in is None: - mask = torch.ones(1, 1, input.shape[2], dtype=input.dtype, device=input.device) - else: - mask = mask_in update_mask = F.conv1d( mask, self.weight_maskUpdater, @@ -60,58 +94,22 @@ def calculate_mask(self, input: torch.Tensor, mask_in: Optional[torch.Tensor]): dilation=self.dilation, groups=1, ) - # for mixed precision training, change 1e-8 to 1e-6 - mask_ratio = self.slide_winsize / (update_mask + 1e-6) + update_mask_filled = torch.masked_fill(update_mask, update_mask == 0, self.slide_winsize) + mask_ratio = self.slide_winsize / update_mask_filled update_mask = torch.clamp(update_mask, 0, 1) - mask_ratio = torch.mul(mask_ratio.to(update_mask), update_mask) - return torch.mul(input, mask), mask_ratio, update_mask - - def forward_aux(self, input: torch.Tensor, mask_ratio: torch.Tensor, update_mask: torch.Tensor) -> torch.Tensor: - assert len(input.shape) == 3 + mask_ratio = torch.mul(mask_ratio, update_mask) raw_out = self._conv_forward(input, self.weight, self.bias) if self.bias is not None: - output = torch.mul(raw_out - self.bias_view, mask_ratio) + self.bias_view + bias_view = self.bias.view(1, self.out_channels, 1) + output = torch.mul(raw_out - bias_view, mask_ratio) + bias_view output = torch.mul(output, update_mask) else: output = torch.mul(raw_out, mask_ratio) return output - @torch.jit.ignore - def forward_with_cache(self, input: torch.Tensor, mask_in: Optional[torch.Tensor] = None) -> torch.Tensor: - use_cache = not (torch.jit.is_tracing() or torch.onnx.is_in_onnx_export()) - cache_hit = use_cache and mask_in is None and self.last_size == input.shape - if cache_hit: - mask_ratio = self.mask_ratio - update_mask = self.update_mask - else: - input, mask_ratio, update_mask = self.calculate_mask(input, mask_in) - if use_cache: - # if a mask is input, or tensor shape changed, update mask ratio - self.last_size = tuple(input.shape) - self.update_mask = update_mask - self.mask_ratio = mask_ratio - return self.forward_aux(input, mask_ratio, update_mask) - - def forward_no_cache(self, input: torch.Tensor, mask_in: Optional[torch.Tensor] = None) -> torch.Tensor: - if self.partial: - input, mask_ratio, update_mask = self.calculate_mask(input, mask_in) - return self.forward_aux(input, mask_ratio, 
update_mask) - else: - if mask_in is not None: - input = torch.mul(input, mask_in) - return self._conv_forward(input, self.weight, self.bias) - - def forward(self, input: torch.Tensor, mask_in: Optional[torch.Tensor] = None) -> torch.Tensor: - if self.partial: - return self.forward_with_cache(input, mask_in) - else: - if mask_in is not None: - input = torch.mul(input, mask_in) - return self._conv_forward(input, self.weight, self.bias) - class LinearNorm(torch.nn.Module): def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): @@ -125,6 +123,9 @@ def forward(self, x): class ConvNorm(torch.nn.Module): + __constants__ = ['use_partial_padding'] + use_partial_padding: bool + def __init__( self, in_channels, @@ -135,16 +136,19 @@ def __init__( dilation=1, bias=True, w_init_gain='linear', - use_partial_padding: bool = False, - use_weight_norm: bool = False, + use_partial_padding=False, + use_weight_norm=False, norm_fn=None, ): super(ConvNorm, self).__init__() if padding is None: assert kernel_size % 2 == 1 padding = int(dilation * (kernel_size - 1) / 2) - self.use_partial_padding: bool = use_partial_padding - conv = PartialConv1d( + self.use_partial_padding = use_partial_padding + conv_fn = torch.nn.Conv1d + if use_partial_padding: + conv_fn = PartialConv1d + self.conv = conv_fn( in_channels, out_channels, kernel_size=kernel_size, @@ -153,20 +157,25 @@ def __init__( dilation=dilation, bias=bias, ) - conv.partial = use_partial_padding - torch.nn.init.xavier_uniform_(conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) + torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) if use_weight_norm: - conv = torch.nn.utils.weight_norm(conv) + self.conv = torch.nn.utils.weight_norm(self.conv) if norm_fn is not None: self.norm = norm_fn(out_channels, affine=True) else: self.norm = None - self.conv = conv - def forward(self, input: torch.Tensor, mask_in: Optional[torch.Tensor] = None) -> torch.Tensor: - ret = self.conv(input, mask_in) - if self.norm is not None: - ret = self.norm(ret) + def forward(self, signal, mask=None): + if self.use_partial_padding: + ret = self.conv(signal, mask) + if self.norm is not None: + ret = self.norm(ret, mask) + else: + if mask is not None: + signal = signal * mask + ret = self.conv(signal) + if self.norm is not None: + ret = self.norm(ret) return ret diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index 5a9ab55a4ee7..e5f7b5231600 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -128,7 +128,7 @@ def _export( # Set module mode with torch.onnx.select_model_mode_for_export( self, training - ), torch.inference_mode(), torch.jit.optimized_execution(True): + ), torch.inference_mode(), torch.no_grad(), torch.jit.optimized_execution(True): if input_example is None: input_example = self.input_module.input_example() @@ -147,12 +147,14 @@ def _export( output_names = self.output_names output_example = tuple(self.forward(*input_list, **input_dict)) + if check_trace: + if isinstance(check_trace, bool): + check_trace_input = [input_example] + else: + check_trace_input = check_trace + if format == ExportFormat.TORCHSCRIPT: - if check_trace: - if isinstance(check_trace, bool): - check_trace_input = {"forward": tuple(input_list) + tuple(input_dict.values())} - else: - check_trace_input = check_trace + jitted_model = torch.jit.trace_module( self, {"forward": tuple(input_list) + tuple(input_dict.values())}, @@ -165,14 +167,9 @@ def _export( if verbose: 
logging.info(f"JIT code:\n{jitted_model.code}") jitted_model.save(output) - assert os.path.exists(output) + jitted_model = torch.jit.load(output) if check_trace: - if isinstance(check_trace, bool): - check_trace_input = [input_example] - else: - check_trace_input = check_trace - verify_torchscript(jitted_model, output, check_trace_input, input_names, check_tolerance) elif format == ExportFormat.ONNX: @@ -196,10 +193,6 @@ def _export( ) if check_trace: - if isinstance(check_trace, bool): - check_trace_input = [input_example] - else: - check_trace_input = check_trace verify_runtime(self, output, check_trace_input, input_names) else: raise ValueError(f'Encountered unknown export format {format}.') diff --git a/nemo/core/optim/radam.py b/nemo/core/optim/radam.py index 62a5ecff87be..69cfab4bf858 100644 --- a/nemo/core/optim/radam.py +++ b/nemo/core/optim/radam.py @@ -81,8 +81,8 @@ def step(self, closure=None): exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] - exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) - exp_avg.mul_(beta1).add_(1 - beta1, grad) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1.0 - beta2)) + exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) state['step'] += 1 buffered = self.buffer[int(state['step'] % 10)] diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py index 9eb064936ea5..f973a4719e24 100644 --- a/nemo/utils/cast_utils.py +++ b/nemo/utils/cast_utils.py @@ -70,6 +70,6 @@ def __init__(self, mod): self.mod = mod def forward(self, x): - with avoid_float16_autocast_context(): + with torch.cuda.amp.autocast(enabled=False): ret = self.mod.forward(x.to(torch.float32)).to(x.dtype) return ret diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py index a5c4e5b3d24f..fbe21b9cf8f8 100644 --- a/nemo/utils/export_utils.py +++ b/nemo/utils/export_utils.py @@ -15,7 +15,7 @@ import os from contextlib import nullcontext from enum import Enum -from typing import Callable, Dict, Optional, Type +from typing import Callable, Dict, List, Optional, Type import onnx import torch @@ -160,7 +160,7 @@ def verify_torchscript(model, output, input_examples, input_names, check_toleran for input_example in input_examples: input_list, input_dict = parse_input_example(input_example) output_example = model.forward(*input_list, **input_dict) - # ts_input = to_onnxrt_input(ort_input_names, input_names, input_dict, input_list) + all_good = all_good and run_ts_and_compare(ts_model, input_list, input_dict, output_example, check_tolerance) status = "SUCCESS" if all_good else "FAIL" logging.info(f"Torchscript generated at {output} verified with torchscript forward : " + status) @@ -203,7 +203,7 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c if torch.is_tensor(expected): tout = out.to('cpu') - logging.info(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}") + logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}") if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance): all_good = False logging.info(f"onnxruntime results mismatch! 
PyTorch(expected):\n{expected}\nTorchScript:\n{tout}") @@ -219,7 +219,7 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01): if torch.is_tensor(expected): tout = torch.from_numpy(out) - logging.info(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}") + logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}") if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance): all_good = False logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}") @@ -418,6 +418,11 @@ def replace_modules( return model +def script_module(m: nn.Module): + m1 = torch.jit.script(m) + return m1 + + default_replacements = { "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat), "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat), @@ -425,6 +430,11 @@ def replace_modules( "MatchedScaleMaskSoftmax": wrap_module(nn.Softmax, ExportableMatchedScaleMaskSoftmax), } +script_replacements = { + "BiLSTM": script_module, + "ConvLSTMLinear": script_module, +} + def replace_for_export(model: nn.Module) -> nn.Module: """ @@ -438,3 +448,5 @@ def replace_for_export(model: nn.Module) -> nn.Module: """ replace_modules(model, default_Apex_replacements) replace_modules(model, default_replacements) + # This one has to be the last + replace_modules(model, script_replacements) diff --git a/scripts/export.py b/scripts/export.py index b3d6317e936c..2e100e446e72 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -143,10 +143,11 @@ def nemo_export(argv): if check_trace and len(in_args) > 0: input_example = model.input_module.input_example(**in_args) check_trace = [input_example] - for key, arg in in_args: + for key, arg in in_args.items(): in_args[key] = (arg + 1) // 2 input_example2 = model.input_module.input_example(**in_args) check_trace.append(input_example2) + logging.info(f"Using additional check args: {in_args}") _, descriptions = model.export( out, diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py index 3c3f13a028a6..e3e496373271 100644 --- a/tests/collections/tts/test_tts_exportables.py +++ b/tests/collections/tts/test_tts_exportables.py @@ -15,8 +15,10 @@ import tempfile import pytest +from omegaconf import OmegaConf -from nemo.collections.tts.models import FastPitchModel, HifiGanModel +from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel +from nemo.utils.app_state import AppState @pytest.fixture() @@ -31,6 +33,27 @@ def hifigan_model(): return model +@pytest.fixture() +def radtts_model(): + this_test_dir = os.path.dirname(os.path.abspath(__file__)) + + cfg = OmegaConf.load(os.path.join(this_test_dir, '../../../examples/tts/conf/rad-tts_feature_pred.yaml')) + cfg.model.init_from_ptl_ckpt = None + cfg.model.train_ds.dataset.manifest_filepath = "dummy.json" + cfg.model.train_ds.dataset.sup_data_path = "dummy.json" + cfg.model.validation_ds.dataset.manifest_filepath = "dummy.json" + cfg.model.validation_ds.dataset.sup_data_path = "dummy.json" + cfg.pitch_mean = 212.35 + cfg.pitch_std = 68.52 + + app_state = AppState() + app_state.is_model_being_restored = True + model = RadTTSModel(cfg=cfg.model) + app_state.is_model_being_restored = False + model.eval() + return model + + class TestExportable: @pytest.mark.run_only_on('GPU') @pytest.mark.unit @@ -50,7 +73,10 @@ def test_HifiGanModel_export_to_onnx(self, hifigan_model): filename = os.path.join(tmpdir, 'hfg.pt') model.export(output=filename, 
verbose=True, check_trace=True) - -if __name__ == "__main__": - t = TestExportable() - t.test_FastPitchModel_export_to_onnx(fastpitch_model()) + @pytest.mark.run_only_on('GPU') + @pytest.mark.unit + def test_RadTTSModel_export_to_torchscript(self, radtts_model): + model = radtts_model.cuda() + with tempfile.TemporaryDirectory() as tmpdir: + filename = os.path.join(tmpdir, 'rad.ts') + model.export(output=filename, verbose=True, check_trace=True) From 5ce64de4ccf34e10c0ec8471f97fafb99ae3ec1e Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 17 Nov 2022 13:23:25 -0800 Subject: [PATCH 2/5] Cleanup Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 5d93340b4c2d..272cd54501fd 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -126,20 +126,14 @@ def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = Fals ) if not (torch.jit.is_scripting() or torch.jit.is_tracing()): self.bilstm.flatten_parameters() - if hasattr(self.bilstm, 'forward'): - ret, _ = self.bilstm.forward(seq) - else: - ret, _ = self.bilstm.forward_1(seq) + ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]: if not (torch.jit.is_scripting() or torch.jit.is_tracing()): self.bilstm.flatten_parameters() - if hasattr(self.bilstm, 'forward'): - ret, _ = self.bilstm.forward(seq) - elif hasattr(self.bilstm, 'forward_1'): - ret, _ = self.bilstm.forward_1(seq) + ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) @torch.jit.export From 6aa61ff36e5cc1a62d8e5aa8b749923272dc5077 Mon Sep 17 00:00:00 2001 From: Boris Fomitchev Date: Thu, 17 Nov 2022 21:16:33 -0800 Subject: [PATCH 3/5] Fixes for Torchscript/Triton Signed-off-by: Boris Fomitchev --- nemo/collections/tts/modules/common.py | 42 +++++++++----------------- nemo/collections/tts/modules/radtts.py | 4 +-- nemo/core/classes/exportable.py | 1 - nemo/utils/cast_utils.py | 2 +- nemo/utils/export_utils.py | 36 ++++++++++++++-------- 5 files changed, 40 insertions(+), 45 deletions(-) diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py index 272cd54501fd..906de05cd8c8 100644 --- a/nemo/collections/tts/modules/common.py +++ b/nemo/collections/tts/modules/common.py @@ -119,33 +119,29 @@ def __init__(self, input_size, hidden_size, num_layers=1, lstm_norm_fn="spectral lstm_norm_fn_pntr(self.bilstm, 'weight_hh_l0_reverse') self.bilstm.flatten_parameters() - @torch.jit.export def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> Tuple[Tensor, Tensor]: seq = nn.utils.rnn.pack_padded_sequence( context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted ) - if not (torch.jit.is_scripting() or torch.jit.is_tracing()): - self.bilstm.flatten_parameters() - ret, _ = self.bilstm(seq) - return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) + return self.lstm_sequence(seq) - @torch.jit.export def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]: if not (torch.jit.is_scripting() or torch.jit.is_tracing()): self.bilstm.flatten_parameters() ret, _ = self.bilstm(seq) return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True) - @torch.jit.export - def sort_and_lstm_tensor(self, 
context: Tensor, lens: Tensor) -> Tensor: + def forward(self, context: Tensor, lens: Tensor) -> Tensor: context, lens_sorted, unsort_ids = sort_tensor(context, lens) - seq = nn.utils.rnn.pack_padded_sequence( - context, lens_sorted.long().cpu(), batch_first=True, enforce_sorted=True - ) - return self.lstm_sequence(seq)[0][unsort_ids] + dtype = context.dtype + # this is only needed for Torchscript to run in Triton + # (https://github.com/pytorch/pytorch/issues/89241) + with torch.cuda.amp.autocast(enabled=False): + ret = self.lstm_tensor(context.to(dtype=torch.float32), lens_sorted, enforce_sorted=True) + return ret[0].to(dtype=dtype)[unsort_ids] -class ConvLSTMLinear(BiLSTM): +class ConvLSTMLinear(nn.Module): def __init__( self, in_dim=None, @@ -157,7 +153,8 @@ def __init__( use_partial_padding=False, norm_fn=None, ): - super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1) + super(ConvLSTMLinear, self).__init__() + self.bilstm = BiLSTM(n_channels, int(n_channels // 2), 1) self.convolutions = nn.ModuleList() if n_layers > 0: @@ -188,27 +185,16 @@ def __init__( if out_dim is not None: self.dense = nn.Linear(n_channels, out_dim) - def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence: + def forward(self, context: Tensor, lens: Tensor) -> Tensor: mask = get_mask_from_lengths_and_val(lens, context) mask = mask.to(dtype=context.dtype).unsqueeze(1) for conv in self.convolutions: context = self.dropout(F.relu(conv(context, mask))) - context = context.transpose(1, 2) - seq = torch.nn.utils.rnn.pack_padded_sequence( - context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted - ) - return seq - - def forward(self, context: Tensor, lens: Tensor) -> Tensor: - context, lens, unsort_ids = sort_tensor(context, lens) - seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=True) - context, _ = self.lstm_sequence(seq) - context = context[unsort_ids] - + # Apply Bidirectional LSTM + context = self.bilstm(context, lens) if self.dense is not None: context = self.dense(context).permute(0, 2, 1) - return context diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py index d41e7dd628e5..c050ff2e2d76 100644 --- a/nemo/collections/tts/modules/radtts.py +++ b/nemo/collections/tts/modules/radtts.py @@ -346,9 +346,7 @@ def preprocess_context(self, context, speaker_vecs, out_lens, f0, energy_avg): context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1) unfolded_out_lens = out_lens // self.n_group_size - context_lstm_padded_output = self.context_lstm.sort_and_lstm_tensor( - context_w_spkvec.transpose(1, 2), unfolded_out_lens - ) + context_lstm_padded_output = self.context_lstm(context_w_spkvec.transpose(1, 2), unfolded_out_lens) context_w_spkvec = context_lstm_padded_output.transpose(1, 2) if not self.context_lstm_w_f0_and_energy: diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index e5f7b5231600..0ac2ea663b57 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os
 from abc import ABC
 from typing import List, Union
diff --git a/nemo/utils/cast_utils.py b/nemo/utils/cast_utils.py
index f973a4719e24..9eb064936ea5 100644
--- a/nemo/utils/cast_utils.py
+++ b/nemo/utils/cast_utils.py
@@ -70,6 +70,6 @@ def __init__(self, mod):
         self.mod = mod
 
     def forward(self, x):
-        with torch.cuda.amp.autocast(enabled=False):
+        with avoid_float16_autocast_context():
             ret = self.mod.forward(x.to(torch.float32)).to(x.dtype)
         return ret
diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py
index fbe21b9cf8f8..197d3b478167 100644
--- a/nemo/utils/export_utils.py
+++ b/nemo/utils/export_utils.py
@@ -15,7 +15,7 @@
 import os
 from contextlib import nullcontext
 from enum import Enum
-from typing import Callable, Dict, List, Optional, Type
+from typing import Callable, Dict, Optional, Type
 
 import onnx
 import torch
@@ -154,14 +154,16 @@ def to_onnxrt_input(ort_input_names, input_names, input_dict, input_list):
 
 
 def verify_torchscript(model, output, input_examples, input_names, check_tolerance=0.01):
-    ts_model = torch.jit.load(output)
-
     all_good = True
     for input_example in input_examples:
         input_list, input_dict = parse_input_example(input_example)
         output_example = model.forward(*input_list, **input_dict)
-
-        all_good = all_good and run_ts_and_compare(ts_model, input_list, input_dict, output_example, check_tolerance)
+        # We disable autocast here to make sure exported TS will run under Triton or other C++ env
+        with torch.cuda.amp.autocast(enabled=False):
+            ts_model = torch.jit.load(output)
+            all_good = all_good and run_ts_and_compare(
+                ts_model, input_list, input_dict, output_example, check_tolerance
+            )
     status = "SUCCESS" if all_good else "FAIL"
     logging.info(f"Torchscript generated at {output} verified with torchscript forward : " + status)
     return all_good
@@ -204,9 +206,15 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c
         if torch.is_tensor(expected):
             tout = out.to('cpu')
             logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
-            if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
+            this_good = True
+            try:
+                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
+                    this_good = False
+            except Exception:  # there may be size mismatch and it may be OK
+                this_good = False
+            if not this_good:
+                logging.info(f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
                 all_good = False
-                logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
     return all_good
@@ -220,9 +228,15 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
         if torch.is_tensor(expected):
             tout = torch.from_numpy(out)
             logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
-            if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
-                all_good = False
+            this_good = True
+            try:
+                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
+                    this_good = False
+            except Exception:  # there may be size mismatch and it may be OK
+                this_good = False
+            if not this_good:
                logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}")
+                all_good = False
     return all_good
 
@@ -419,8 +433,7 @@ def replace_modules(
 
 
 def script_module(m: nn.Module):
-    m1 = torch.jit.script(m)
-    return m1
+    return torch.jit.script(m)
 
 
 default_replacements = {
@@ -432,7 +445,6 @@ def script_module(m: nn.Module):
 
 script_replacements = {
     "BiLSTM": script_module,
-    "ConvLSTMLinear": script_module,
 }
 

From d4f48e1500c9bdaffd63ce11a84e18ccd1fbb44e Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Fri, 18 Nov 2022 21:19:14 -0800
Subject: [PATCH 4/5] Added autocast to radtts UT

Signed-off-by: Boris Fomitchev
---
 nemo/collections/tts/modules/radtts.py        | 4 ++--
 tests/collections/tts/test_tts_exportables.py | 4 +++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py
index c050ff2e2d76..aca6a5c44727 100644
--- a/nemo/collections/tts/modules/radtts.py
+++ b/nemo/collections/tts/modules/radtts.py
@@ -771,8 +771,8 @@ def input_example(self, max_batch=1, max_dim=256):
         """
         par = next(self.parameters())
         sz = (max_batch, max_dim)
-        inp = torch.randint(0, 16, sz, device=par.device, dtype=torch.int64)
-        lens = torch.randint(16, max_dim, (max_batch,), device=par.device, dtype=torch.int)
+        inp = torch.randint(16, 32, sz, device=par.device, dtype=torch.int64)
+        lens = torch.randint(max_dim // 4, max_dim // 2, (max_batch,), device=par.device, dtype=torch.int)
         speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64)
         inputs = {
             'text': inp,
diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py
index e3e496373271..bf2c0842eb91 100644
--- a/tests/collections/tts/test_tts_exportables.py
+++ b/tests/collections/tts/test_tts_exportables.py
@@ -15,6 +15,7 @@
 import tempfile
 
 import pytest
+import torch
 from omegaconf import OmegaConf
 
 from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel
@@ -79,4 +80,5 @@ def test_RadTTSModel_export_to_torchscript(self, radtts_model):
         model = radtts_model.cuda()
         with tempfile.TemporaryDirectory() as tmpdir:
             filename = os.path.join(tmpdir, 'rad.ts')
-            model.export(output=filename, verbose=True, check_trace=True)
+            with torch.cuda.amp.autocast(enabled=True):
+                model.export(output=filename, verbose=True, check_trace=True)

From 00e1fec90d3c7552b69fd9fe32ae96c95d4e6897 Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Sun, 20 Nov 2022 20:45:23 -0800
Subject: [PATCH 5/5] using cuda() for training example

Signed-off-by: Boris Fomitchev
---
 examples/tts/radtts.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/tts/radtts.py b/examples/tts/radtts.py
index 7260e8d9907f..7dbdaedced03 100644
--- a/examples/tts/radtts.py
+++ b/examples/tts/radtts.py
@@ -61,14 +61,14 @@ def prepare_model_weights(model, unfreeze_modules):
 def main(cfg):
     trainer = pl.Trainer(**cfg.trainer)
     exp_manager(trainer, cfg.get('exp_manager', None))
-    model = RadTTSModel(cfg=cfg.model, trainer=trainer)
+    model = RadTTSModel(cfg=cfg.model, trainer=trainer).cuda()
     if cfg.model.load_from_checkpoint:
         model.maybe_init_from_pretrained_checkpoint(cfg=cfg.model)
         prepare_model_weights(model, cfg.model.trainerConfig.unfreeze_modules)
     lr_logger = pl.callbacks.LearningRateMonitor()
     epoch_time_logger = LogEpochTimeCallback()
     trainer.callbacks.extend([lr_logger, epoch_time_logger])
-    trainer.fit(model)
+    trainer.fit(model.cuda())
 
 
 if __name__ == '__main__':
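
A minimal sketch of producing and exercising the exported artifact end to end, mirroring test_RadTTSModel_export_to_torchscript and verify_torchscript from this series. The checkpoint path "radtts_model.nemo" and output "rad.ts" are placeholders; the five-tensor call signature follows forward_for_export, and the input ranges follow input_example as changed in PATCH 4/5:

# Sketch only, not part of this patch series; paths below are placeholders.
import torch

from nemo.collections.tts.models import RadTTSModel

model = RadTTSModel.restore_from("radtts_model.nemo").cuda().eval()

# Export under autocast, as the updated unit test does; check_trace=True
# replays input_example() through the scripted module and compares outputs.
with torch.cuda.amp.autocast(enabled=True):
    model.export(output="rad.ts", verbose=False, check_trace=True)

# The saved TorchScript module is self-contained. As in verify_torchscript,
# run it with autocast disabled so it behaves the same under Triton or C++.
ts_model = torch.jit.load("rad.ts")
text = torch.randint(16, 32, (1, 256), dtype=torch.int64, device='cuda')  # token ids, as in input_example
lens = torch.randint(64, 128, (1,), dtype=torch.int, device='cuda')       # text lengths (max_dim // 4 .. max_dim // 2)
speaker = torch.zeros(1, dtype=torch.int64, device='cuda')                # speaker / speaker_text / speaker_attributes ids
with torch.inference_mode(), torch.cuda.amp.autocast(enabled=False):
    outputs = ts_model(text, lens, speaker, speaker, speaker)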