diff --git a/examples/tts/radtts.py b/examples/tts/radtts.py
index 7260e8d9907f..7dbdaedced03 100644
--- a/examples/tts/radtts.py
+++ b/examples/tts/radtts.py
@@ -61,14 +61,14 @@ def prepare_model_weights(model, unfreeze_modules):
 def main(cfg):
     trainer = pl.Trainer(**cfg.trainer)
     exp_manager(trainer, cfg.get('exp_manager', None))
-    model = RadTTSModel(cfg=cfg.model, trainer=trainer)
+    model = RadTTSModel(cfg=cfg.model, trainer=trainer).cuda()
     if cfg.model.load_from_checkpoint:
         model.maybe_init_from_pretrained_checkpoint(cfg=cfg.model)
         prepare_model_weights(model, cfg.model.trainerConfig.unfreeze_modules)
     lr_logger = pl.callbacks.LearningRateMonitor()
     epoch_time_logger = LogEpochTimeCallback()
     trainer.callbacks.extend([lr_logger, epoch_time_logger])
-    trainer.fit(model)
+    trainer.fit(model.cuda())
 
 
 if __name__ == '__main__':
diff --git a/nemo/collections/tts/modules/common.py b/nemo/collections/tts/modules/common.py
index 5d93340b4c2d..906de05cd8c8 100644
--- a/nemo/collections/tts/modules/common.py
+++ b/nemo/collections/tts/modules/common.py
@@ -119,39 +119,29 @@ def __init__(self, input_size, hidden_size, num_layers=1, lstm_norm_fn="spectral
             lstm_norm_fn_pntr(self.bilstm, 'weight_hh_l0_reverse')
         self.bilstm.flatten_parameters()
 
-    @torch.jit.export
     def lstm_tensor(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> Tuple[Tensor, Tensor]:
         seq = nn.utils.rnn.pack_padded_sequence(
             context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
         )
-        if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
-            self.bilstm.flatten_parameters()
-        if hasattr(self.bilstm, 'forward'):
-            ret, _ = self.bilstm.forward(seq)
-        else:
-            ret, _ = self.bilstm.forward_1(seq)
-        return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)
+        return self.lstm_sequence(seq)
 
-    @torch.jit.export
     def lstm_sequence(self, seq: PackedSequence) -> Tuple[Tensor, Tensor]:
         if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
             self.bilstm.flatten_parameters()
-        if hasattr(self.bilstm, 'forward'):
-            ret, _ = self.bilstm.forward(seq)
-        elif hasattr(self.bilstm, 'forward_1'):
-            ret, _ = self.bilstm.forward_1(seq)
+        ret, _ = self.bilstm(seq)
         return nn.utils.rnn.pad_packed_sequence(ret, batch_first=True)
 
-    @torch.jit.export
-    def sort_and_lstm_tensor(self, context: Tensor, lens: Tensor) -> Tensor:
+    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
         context, lens_sorted, unsort_ids = sort_tensor(context, lens)
-        seq = nn.utils.rnn.pack_padded_sequence(
-            context, lens_sorted.long().cpu(), batch_first=True, enforce_sorted=True
-        )
-        return self.lstm_sequence(seq)[0][unsort_ids]
+        dtype = context.dtype
+        # this is only needed for Torchscript to run in Triton
+        # (https://github.com/pytorch/pytorch/issues/89241)
+        with torch.cuda.amp.autocast(enabled=False):
+            ret = self.lstm_tensor(context.to(dtype=torch.float32), lens_sorted, enforce_sorted=True)
+        return ret[0].to(dtype=dtype)[unsort_ids]
 
 
-class ConvLSTMLinear(BiLSTM):
+class ConvLSTMLinear(nn.Module):
     def __init__(
         self,
         in_dim=None,
@@ -163,7 +153,8 @@ def __init__(
         use_partial_padding=False,
         norm_fn=None,
     ):
-        super(ConvLSTMLinear, self).__init__(n_channels, int(n_channels // 2), 1)
+        super(ConvLSTMLinear, self).__init__()
+        self.bilstm = BiLSTM(n_channels, int(n_channels // 2), 1)
         self.convolutions = nn.ModuleList()
 
         if n_layers > 0:
@@ -194,27 +185,16 @@ def __init__(
         if out_dim is not None:
            self.dense = nn.Linear(n_channels, out_dim)
 
-    def masked_conv_to_sequence(self, context: Tensor, lens: Tensor, enforce_sorted: bool = False) -> PackedSequence:
+    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
         mask = get_mask_from_lengths_and_val(lens, context)
         mask = mask.to(dtype=context.dtype).unsqueeze(1)
         for conv in self.convolutions:
             context = self.dropout(F.relu(conv(context, mask)))
-
         context = context.transpose(1, 2)
-        seq = torch.nn.utils.rnn.pack_padded_sequence(
-            context, lens.long().cpu(), batch_first=True, enforce_sorted=enforce_sorted
-        )
-        return seq
-
-    def forward(self, context: Tensor, lens: Tensor) -> Tensor:
-        context, lens, unsort_ids = sort_tensor(context, lens)
-        seq = self.masked_conv_to_sequence(context, lens, enforce_sorted=True)
-        context, _ = self.lstm_sequence(seq)
-        context = context[unsort_ids]
-
+        # Apply Bidirectional LSTM
+        context = self.bilstm(context, lens)
         if self.dense is not None:
             context = self.dense(context).permute(0, 2, 1)
-
         return context
 
 
diff --git a/nemo/collections/tts/modules/radtts.py b/nemo/collections/tts/modules/radtts.py
index d41e7dd628e5..aca6a5c44727 100644
--- a/nemo/collections/tts/modules/radtts.py
+++ b/nemo/collections/tts/modules/radtts.py
@@ -346,9 +346,7 @@ def preprocess_context(self, context, speaker_vecs, out_lens, f0, energy_avg):
             context_w_spkvec = torch.cat((context_w_spkvec, energy_avg), 1)
 
         unfolded_out_lens = out_lens // self.n_group_size
-        context_lstm_padded_output = self.context_lstm.sort_and_lstm_tensor(
-            context_w_spkvec.transpose(1, 2), unfolded_out_lens
-        )
+        context_lstm_padded_output = self.context_lstm(context_w_spkvec.transpose(1, 2), unfolded_out_lens)
         context_w_spkvec = context_lstm_padded_output.transpose(1, 2)
 
         if not self.context_lstm_w_f0_and_energy:
@@ -773,8 +771,8 @@ def input_example(self, max_batch=1, max_dim=256):
         """
         par = next(self.parameters())
         sz = (max_batch, max_dim)
-        inp = torch.randint(0, 16, sz, device=par.device, dtype=torch.int64)
-        lens = torch.randint(16, max_dim, (max_batch,), device=par.device, dtype=torch.int)
+        inp = torch.randint(16, 32, sz, device=par.device, dtype=torch.int64)
+        lens = torch.randint(max_dim // 4, max_dim // 2, (max_batch,), device=par.device, dtype=torch.int)
         speaker = torch.randint(0, 1, (max_batch,), device=par.device, dtype=torch.int64)
         inputs = {
             'text': inp,
diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py
index fbe21b9cf8f8..197d3b478167 100644
--- a/nemo/utils/export_utils.py
+++ b/nemo/utils/export_utils.py
@@ -15,7 +15,7 @@
 import os
 from contextlib import nullcontext
 from enum import Enum
-from typing import Callable, Dict, List, Optional, Type
+from typing import Callable, Dict, Optional, Type
 
 import onnx
 import torch
@@ -154,14 +154,16 @@ def to_onnxrt_input(ort_input_names, input_names, input_dict, input_list):
 
 
 def verify_torchscript(model, output, input_examples, input_names, check_tolerance=0.01):
-    ts_model = torch.jit.load(output)
-
     all_good = True
     for input_example in input_examples:
         input_list, input_dict = parse_input_example(input_example)
         output_example = model.forward(*input_list, **input_dict)
-
-        all_good = all_good and run_ts_and_compare(ts_model, input_list, input_dict, output_example, check_tolerance)
+        # We disable autocast here to make sure exported TS will run under Triton or other C++ env
+        with torch.cuda.amp.autocast(enabled=False):
+            ts_model = torch.jit.load(output)
+            all_good = all_good and run_ts_and_compare(
+                ts_model, input_list, input_dict, output_example, check_tolerance
+            )
     status = "SUCCESS" if all_good else "FAIL"
     logging.info(f"Torchscript generated at {output} verified with torchscript forward : " + status)
     return all_good
@@ -204,9 +206,15 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c
         if torch.is_tensor(expected):
             tout = out.to('cpu')
             logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
-            if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
+            this_good = True
+            try:
+                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
+                    this_good = False
+            except Exception:  # there may be a size mismatch, and that may be OK
+                this_good = False
+            if not this_good:
+                logging.info(f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
                 all_good = False
-                logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
     return all_good
 
 
@@ -220,9 +228,15 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
         if torch.is_tensor(expected):
             tout = torch.from_numpy(out)
             logging.debug(f"Checking output {i}, shape: {expected.shape}:\n{expected}\n{tout}")
-            if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
-                all_good = False
+            this_good = True
+            try:
+                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
+                    this_good = False
+            except Exception:  # there may be a size mismatch, and that may be OK
+                this_good = False
+            if not this_good:
                 logging.info(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}")
+                all_good = False
     return all_good
 
 
@@ -419,8 +433,7 @@ def replace_modules(
 
 
 def script_module(m: nn.Module):
-    m1 = torch.jit.script(m)
-    return m1
+    return torch.jit.script(m)
 
 
 default_replacements = {
@@ -432,7 +445,6 @@ def script_module(m: nn.Module):
 
 script_replacements = {
     "BiLSTM": script_module,
-    "ConvLSTMLinear": script_module,
 }
 
diff --git a/tests/collections/tts/test_tts_exportables.py b/tests/collections/tts/test_tts_exportables.py
index e3e496373271..bf2c0842eb91 100644
--- a/tests/collections/tts/test_tts_exportables.py
+++ b/tests/collections/tts/test_tts_exportables.py
@@ -15,6 +15,7 @@
 import tempfile
 
 import pytest
+import torch
 from omegaconf import OmegaConf
 
 from nemo.collections.tts.models import FastPitchModel, HifiGanModel, RadTTSModel
@@ -79,4 +80,5 @@ def test_RadTTSModel_export_to_torchscript(self, radtts_model):
         model = radtts_model.cuda()
         with tempfile.TemporaryDirectory() as tmpdir:
             filename = os.path.join(tmpdir, 'rad.ts')
-            model.export(output=filename, verbose=True, check_trace=True)
+            with torch.cuda.amp.autocast(enabled=True):
+                model.export(output=filename, verbose=True, check_trace=True)
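
Note for reviewers (not part of the patch): the new BiLSTM.forward above pins the packed-sequence LSTM to float32 inside a disabled-autocast region, so the scripted module also behaves correctly in C++ runtimes such as Triton (pytorch/pytorch#89241); the test change then exercises export under an enabled autocast region to cover that path. Below is a minimal standalone sketch of the same pattern, assuming only torch; Float32BiLSTM and the demo code are hypothetical names for illustration, not NeMo APIs.

# Reviewer sketch only -- illustrates the autocast workaround used by the patch.
import torch
import torch.nn as nn


class Float32BiLSTM(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.bilstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)

    def forward(self, x: torch.Tensor, lens: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        # Force true float32 execution even under an outer autocast region;
        # scripted LSTMs misbehave under autocast in C++ environments
        # (https://github.com/pytorch/pytorch/issues/89241).
        with torch.cuda.amp.autocast(enabled=False):
            seq = nn.utils.rnn.pack_padded_sequence(
                x.to(dtype=torch.float32), lens.long().cpu(), batch_first=True, enforce_sorted=False
            )
            out, _ = self.bilstm(seq)
            out, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        return out.to(dtype=dtype)  # restore the caller's dtype


if __name__ == '__main__' and torch.cuda.is_available():
    model = torch.jit.script(Float32BiLSTM(8, 4).cuda())
    x = torch.randn(2, 5, 8, device='cuda')
    lens = torch.tensor([5, 3])  # lengths stay on CPU for pack_padded_sequence
    with torch.cuda.amp.autocast(enabled=True):
        print(model(x, lens).dtype)  # torch.float32, unaffected by the outer autocast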