Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding 'use_dynamo' option for export to use onnx.dynamo_export() instead of onnx.export() #9147

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions Dockerfile.ci
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pip install --no-cache-dir --no-build-isolation --extra-index-url https://pypi.n
"nvidia-modelopt[torch]~=${MODELOPT_VERSION}" \
"apex @ git+https://github.com/NVIDIA/apex.git@${APEX_TAG}" \
"llama-index==0.10.43" \
"onnxscript @ git+https://github.com/microsoft/onnxscript" \
-r tools/ctc_segmentation/requirements.txt \
".[all]"

Expand Down
8 changes: 4 additions & 4 deletions nemo/collections/asr/models/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,12 @@ def output_names(self):
if getattr(self.input_module, 'export_cache_support', False):
in_types = self.input_module.output_types
otypes = {n: t for (n, t) in list(otypes.items())[:1]}
for (n, t) in list(in_types.items())[1:]:
for n, t in list(in_types.items())[1:]:
otypes[n] = t
return get_io_names(otypes, self.disabled_deployment_output_names)

def forward_for_export(
self, input, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
self, audio_signal, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
):
"""
This forward is used when we need to export the model to ONNX format.
Expand All @@ -264,12 +264,12 @@ def forward_for_export(
"""
enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward)
if cache_last_channel is None:
encoder_output = enc_fun(audio_signal=input, length=length)
encoder_output = enc_fun(audio_signal=audio_signal, length=length)
if isinstance(encoder_output, tuple):
encoder_output = encoder_output[0]
else:
encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun(
audio_signal=input,
audio_signal=audio_signal,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,8 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
"embs": NeuralType(('B', 'D'), AcousticEncodedRepresentation()),
}

def forward_for_export(self, processed_signal, processed_signal_len):
encoded, length = self.encoder(audio_signal=processed_signal, length=processed_signal_len)
def forward_for_export(self, audio_signal, length):
encoded, length = self.encoder(audio_signal=audio_signal, length=length)
logits, embs = self.decoder(encoder_output=encoded, length=length)
return logits, embs

Expand Down
70 changes: 37 additions & 33 deletions nemo/collections/asr/models/msdd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,7 @@ def add_speaker_model_config(self, cfg):
del cfg.speaker_model_cfg.validation_ds

def _init_segmentation_info(self):
"""Initialize segmentation settings: window, shift and multiscale weights.
"""
"""Initialize segmentation settings: window, shift and multiscale weights."""
self._diarizer_params = self.cfg_msdd_model.diarizer
self.multiscale_args_dict = parse_scale_configs(
self._diarizer_params.speaker_embeddings.parameters.window_length_in_sec,
Expand Down Expand Up @@ -275,10 +274,14 @@ def __setup_dataloader_from_config_infer(
)

def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
self._train_dl = self.__setup_dataloader_from_config(config=train_data_config,)
self._train_dl = self.__setup_dataloader_from_config(
config=train_data_config,
)

def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]):
self._validation_dl = self.__setup_dataloader_from_config(config=val_data_layer_config,)
self._validation_dl = self.__setup_dataloader_from_config(
config=val_data_layer_config,
)

def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
if self.pairwise_infer:
Expand Down Expand Up @@ -338,32 +341,32 @@ def get_ms_emb_seq(
Merged embeddings without zero-padding in the batch. See `ms_seg_counts` for details.
Shape: (Total number of segments in the batch, emb_dim)
scale_mapping (Tensor):
The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale
segment index which has the closest center distance with (n+1)-th segment in the base scale.
Example:
scale_mapping_argmat[2][101] = 85
In the above example, it means that 86-th segment in the 3rd scale (python index is 2) is mapped with
102-th segment in the base scale. Thus, the longer segments bound to have more repeating numbers since
multiple base scale segments (since the base scale has the shortest length) fall into the range of the
longer segments. At the same time, each row contains N numbers of indices where N is number of
segments in the base-scale (i.e., the finest scale).
The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale
segment index which has the closest center distance with (n+1)-th segment in the base scale.
Example:
scale_mapping_argmat[2][101] = 85
In the above example, it means that 86-th segment in the 3rd scale (python index is 2) is mapped with
102-th segment in the base scale. Thus, the longer segments bound to have more repeating numbers since
multiple base scale segments (since the base scale has the shortest length) fall into the range of the
longer segments. At the same time, each row contains N numbers of indices where N is number of
segments in the base-scale (i.e., the finest scale).
Shape: (batch_size, scale_n, self.diar_window_length)
ms_seg_counts (Tensor):
Cumulative sum of the number of segments in each scale. This information is needed to reconstruct
the multi-scale input matrix during forward propagating.

Example: `batch_size=3, scale_n=6, emb_dim=192`
ms_seg_counts =
[[8, 9, 12, 16, 25, 51],
[11, 13, 14, 17, 25, 51],
[ 9, 9, 11, 16, 23, 50]]
Example: `batch_size=3, scale_n=6, emb_dim=192`
ms_seg_counts =
[[8, 9, 12, 16, 25, 51],
[11, 13, 14, 17, 25, 51],
[ 9, 9, 11, 16, 23, 50]]

In this function, `ms_seg_counts` is used to get the actual length of each embedding sequence without
zero-padding.
In this function, `ms_seg_counts` is used to get the actual length of each embedding sequence without
zero-padding.

Returns:
ms_emb_seq (Tensor):
Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less repeated,
Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less repeated,
while shorter scales are more frequently repeated following the scale mapping tensor.
"""
scale_n, batch_size = scale_mapping[0].shape[0], scale_mapping.shape[0]
Expand Down Expand Up @@ -409,9 +412,9 @@ def get_cluster_avg_embs_model(
[ 9, 9, 11, 16, 23, 50]
]

Counts of merged segments: (121, 131, 118)
embs has shape of (370, 192)
clus_label_index has shape of (3, 131)
Counts of merged segments: (121, 131, 118)
embs has shape of (370, 192)
clus_label_index has shape of (3, 131)

Shape: (batch_size, scale_n)

Expand Down Expand Up @@ -553,7 +556,7 @@ def forward(
with torch.no_grad():
self.msdd._speaker_model.eval()
logits, embs_d = self.msdd._speaker_model.forward_for_export(
processed_signal=audio_signal[detach_ids[1]], processed_signal_len=audio_signal_len[detach_ids[1]]
audio_signal=audio_signal[detach_ids[1]], length=audio_signal_len[detach_ids[1]]
)
embs = torch.zeros(audio_signal.shape[0], embs_d.shape[1]).to(embs_d.device)
embs[detach_ids[1], :] = embs_d.detach()
Expand Down Expand Up @@ -854,9 +857,9 @@ def run_clustering_diarizer(self, manifest_filepath: str, emb_dir: str):
os.makedirs(self.out_rttm_dir, exist_ok=True)

self.clus_diar_model._cluster_params = self.cfg_diar_infer.diarizer.clustering.parameters
self.clus_diar_model.multiscale_args_dict[
"multiscale_weights"
] = self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.multiscale_weights
self.clus_diar_model.multiscale_args_dict["multiscale_weights"] = (
self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.multiscale_weights
)
self.clus_diar_model._diarizer_params.speaker_embeddings.parameters = (
self.cfg_diar_infer.diarizer.speaker_embeddings.parameters
)
Expand Down Expand Up @@ -1076,7 +1079,6 @@ def extract_standalone_speaker_model(self, prefix: str = 'msdd._speaker_model.')
return _speaker_model

def _init_msdd_model(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig]):

"""
Initialized MSDD model with the provided config. Load either from `.nemo` file or `.ckpt` checkpoint files.
"""
Expand Down Expand Up @@ -1128,7 +1130,7 @@ def get_pred_mat(self, data_list: List[Union[Tuple[int], List[torch.Tensor]]]) -
digit_map = dict(zip(sorted(set(all_tups)), range(n_est_spks)))
total_len = max([sess[1].shape[1] for sess in data_list])
sum_pred = torch.zeros(total_len, n_est_spks)
for (_dim_tup, pred_mat) in data_list:
for _dim_tup, pred_mat in data_list:
dim_tup = [digit_map[x] for x in _dim_tup]
if len(pred_mat.shape) == 3:
pred_mat = pred_mat.squeeze(0)
Expand Down Expand Up @@ -1167,8 +1169,7 @@ def get_integrated_preds_list(
return output_list

def get_emb_clus_infer(self, cluster_embeddings):
"""Assign dictionaries containing the clustering results from the class instance `cluster_embeddings`.
"""
"""Assign dictionaries containing the clustering results from the class instance `cluster_embeddings`."""
self.msdd_model.emb_sess_test_dict = cluster_embeddings.emb_sess_test_dict
self.msdd_model.clus_test_label_dict = cluster_embeddings.clus_test_label_dict
self.msdd_model.emb_seq_test = cluster_embeddings.emb_seq_test
Expand Down Expand Up @@ -1456,7 +1457,10 @@ def from_pretrained(
"""
logging.setLevel(logging.INFO if verbose else logging.WARNING)
cfg = NeuralDiarizerInferenceConfig.init_config(
diar_model_path=model_name, vad_model_path=vad_model_name, map_location=map_location, verbose=verbose,
diar_model_path=model_name,
vad_model_path=vad_model_name,
map_location=map_location,
verbose=verbose,
)
return cls(cfg)

Expand Down
3 changes: 1 addition & 2 deletions nemo/collections/asr/modules/conformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def streaming_post_process(self, rets, keep_all_outputs=True):
def forward(
self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
):
self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device)
return self.forward_internal(
audio_signal,
length,
Expand All @@ -512,8 +513,6 @@ def forward(
def forward_internal(
self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
):
self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device)

if length is None:
length = audio_signal.new_full(
(audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device
Expand Down
29 changes: 16 additions & 13 deletions nemo/collections/asr/parts/preprocessing/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def clean_spectrogram_batch(spectrogram: torch.Tensor, spectrogram_len: torch.Te


def splice_frames(x, frame_splicing):
""" Stacks frames together across feature dim
"""Stacks frames together across feature dim

input is batch_size, feature_dim, num_frames
output is batch_size, feature_dim*frame_splicing, num_frames
Expand Down Expand Up @@ -261,7 +261,7 @@ def __init__(
highfreq=None,
log=True,
log_zero_guard_type="add",
log_zero_guard_value=2 ** -24,
log_zero_guard_value=2**-24,
dither=CONSTANT,
pad_to=16,
max_duration=16.7,
Expand Down Expand Up @@ -308,6 +308,7 @@ def __init__(
self.hop_length = n_window_stride
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
self.exact_pad = exact_pad

if exact_pad:
logging.info("STFT using exact pad")
Expand All @@ -321,15 +322,6 @@ def __init__(
window_fn = torch_windows.get(window, None)
window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
self.register_buffer("window", window_tensor)
self.stft = lambda x: torch.stft(
x,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
center=False if exact_pad else True,
window=self.window.to(dtype=torch.float),
return_complex=True,
)

self.normalize = normalize
self.log = log
Expand Down Expand Up @@ -388,6 +380,17 @@ def __init__(
logging.debug(f"using grads: {use_grads}")
logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}")

def stft(self, x):
    """Compute the complex short-time Fourier transform of ``x``.

    Thin wrapper around ``torch.stft`` that uses the FFT size, hop length and
    window length stored on this instance.

    Args:
        x: Input signal tensor, passed to ``torch.stft`` unchanged.

    Returns:
        The complex-valued tensor produced by ``torch.stft`` with
        ``return_complex=True``.
    """
    # The window buffer is registered as None when no window function is
    # configured; torch.stft accepts window=None, so only cast a real tensor.
    window = self.window
    if window is not None:
        window = window.to(dtype=torch.float)

    return torch.stft(
        x,
        n_fft=self.n_fft,
        hop_length=self.hop_length,
        win_length=self.win_length,
        # Centering is switched off whenever exact padding is requested.
        center=not self.exact_pad,
        window=window,
        return_complex=True,
    )

def log_zero_guard_value_fn(self, x):
if isinstance(self.log_zero_guard_value, str):
if self.log_zero_guard_value == "tiny":
Expand Down Expand Up @@ -508,7 +511,7 @@ def __init__(
highfreq: Optional[float] = None,
log: bool = True,
log_zero_guard_type: str = "add",
log_zero_guard_value: Union[float, str] = 2 ** -24,
log_zero_guard_value: Union[float, str] = 2**-24,
dither: float = 1e-5,
window: str = "hann",
pad_to: int = 0,
Expand Down Expand Up @@ -579,7 +582,7 @@ def __init__(

@property
def filter_banks(self):
""" Matches the analogous class """
"""Matches the analogous class"""
return self._mel_spec_extractor.mel_scale.fb

def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float:
Expand Down
6 changes: 3 additions & 3 deletions nemo/collections/asr/parts/submodules/jasper.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def forward_for_export(self, x, lengths):
mask = self.make_pad_mask(lengths, max_audio_length=max_len, device=x.device)
mask = ~mask # 0 represents value, 1 represents pad
x = x.float() # For stable AMP, SE must be computed at fp32.
x.masked_fill_(mask, 0.0) # mask padded values explicitly to 0
x = x.masked_fill(mask, 0.0) # mask padded values explicitly to 0
y = self._se_pool_step(x, mask) # [B, C, 1]
y = y.transpose(1, -1) # [B, 1, C]
y = self.fc(y) # [B, 1, C]
Expand Down Expand Up @@ -510,8 +510,8 @@ def _se_pool_step(self, x, mask):
return y

def set_max_len(self, max_len, seq_range=None):
""" Sets maximum input length.
Pre-calculates internal seq_range mask.
"""Sets maximum input length.
Pre-calculates internal seq_range mask.
"""
self.max_len = max_len
if seq_range is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@ def __getitem__(self, idx):


def build_train_valid_test_datasets(
cfg, retro_config: RetroConfig, train_valid_test_num_samples, seq_length, tokenizer,
cfg,
retro_config: RetroConfig,
train_valid_test_num_samples,
seq_length,
tokenizer,
):

# gpt dataset
Expand All @@ -135,7 +139,10 @@ def build_train_valid_test_datasets(
}

retro_train_ds, retro_valid_ds, retro_test_ds = get_retro_datasets(
config=retro_config, gpt_datasets=gpt_datasets, sample_length=seq_length, eod_token_id=tokenizer.eos_id,
config=retro_config,
gpt_datasets=gpt_datasets,
sample_length=seq_length,
eod_token_id=tokenizer.eos_id,
)

train_ds = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults

try:
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
Expand Down
22 changes: 13 additions & 9 deletions nemo/collections/tts/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1, pre_lnorm=Fals
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.scale = 1 / (d_head ** 0.5)
self.scale = 1 / (d_head**0.5)
self.pre_lnorm = pre_lnorm

self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
Expand All @@ -125,13 +125,17 @@ def _forward(self, inp, attn_mask=None, conditioning=None):

head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)

head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)
s0 = inp.size(0)
s1 = inp.size(1)
s2 = s0 * n_head

q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
head_q = head_q.view(s0, s1, n_head, d_head)
head_k = head_k.view(s0, s1, n_head, d_head)
head_v = head_v.view(s0, s1, n_head, d_head)

q = head_q.permute(2, 0, 1, 3).reshape(s2, s1, d_head)
k = head_k.permute(2, 0, 1, 3).reshape(s2, s1, d_head)
v = head_v.permute(2, 0, 1, 3).reshape(s2, s1, d_head)

attn_score = torch.bmm(q, k.transpose(1, 2))
attn_score.mul_(self.scale)
Expand All @@ -145,8 +149,8 @@ def _forward(self, inp, attn_mask=None, conditioning=None):
attn_prob = self.dropatt(attn_prob)
attn_vec = torch.bmm(attn_prob, v)

attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), n_head * d_head)
attn_vec = attn_vec.view(n_head, s0, s1, d_head)
attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(s0, s1, n_head * d_head)

# linear projection
attn_out = self.o_net(attn_vec)
Expand Down