Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding 'use_dynamo' option for export to use onnx.dynamo_export() instead of onnx.export() #9147

Merged
merged 21 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
53102bc
Ininial WARs to implement dynamo option for export
borisfom May 8, 2024
d3c41f7
including weights in .onnx
borisfom May 8, 2024
e9e81b0
dynamo_export works for many small models
borisfom May 16, 2024
7907435
External weights behaviour fixed
borisfom May 17, 2024
732c119
Cleanup
borisfom Jun 7, 2024
a5caddb
Merge remote-tracking branch 'upstream/main' into undecorate-typecheck
borisfom Jun 7, 2024
ea50c5e
Merge remote-tracking branch 'upstream/main' into undecorate-typecheck
borisfom Jun 13, 2024
dd21b74
Apply isort and black reformatting
borisfom Jun 13, 2024
6c299d6
print cleaned up
borisfom Jun 13, 2024
7dcdc76
Merge branch 'undecorate-typecheck' of github.com:borisfom/NeMo into …
borisfom Jun 13, 2024
f425d8a
Added overloadable dynamic_shapes_for_export
borisfom Jun 13, 2024
c7a5e84
Addressing code review
borisfom Jun 15, 2024
1a28fe1
Merge remote-tracking branch 'upstream/main' into undecorate-typecheck
borisfom Jun 15, 2024
a88f1c2
Merge remote-tracking branch 'upstream/main' into undecorate-typecheck
borisfom Jun 18, 2024
21a5882
Fixing CI issues
borisfom Jun 19, 2024
7a8209a
Merge remote-tracking branch 'upstream/main' into undecorate-typecheck
borisfom Jun 26, 2024
9e04cc8
Merge branch 'main' into undecorate-typecheck
ericharper Jun 26, 2024
5289130
Fixing CI test failure
borisfom Jun 26, 2024
a762cc2
Merge branch 'undecorate-typecheck' of github.com:borisfom/NeMo into …
borisfom Jun 26, 2024
f804d26
Eliminated test cross-contamination
borisfom Jun 27, 2024
f6b2b5a
Merge branch 'main' into undecorate-typecheck
titu1994 Jun 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Dockerfile.ci
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pip install --no-cache-dir --no-build-isolation --extra-index-url https://pypi.n
"nvidia-modelopt[torch]~=${MODELOPT_VERSION}" \
"apex @ git+https://github.com/NVIDIA/apex.git@${APEX_TAG}" \
"llama-index==0.10.43" \
"onnxscript @ git+https://github.com/microsoft/onnxscript" \
-r tools/ctc_segmentation/requirements.txt \
".[all]"

Expand Down
8 changes: 4 additions & 4 deletions nemo/collections/asr/models/asr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,12 @@ def output_names(self):
if getattr(self.input_module, 'export_cache_support', False):
in_types = self.input_module.output_types
otypes = {n: t for (n, t) in list(otypes.items())[:1]}
for (n, t) in list(in_types.items())[1:]:
for n, t in list(in_types.items())[1:]:
otypes[n] = t
return get_io_names(otypes, self.disabled_deployment_output_names)

def forward_for_export(
self, input, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
self, audio_signal, length=None, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
):
"""
This forward is used when we need to export the model to ONNX format.
Expand All @@ -264,12 +264,12 @@ def forward_for_export(
"""
enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward)
if cache_last_channel is None:
encoder_output = enc_fun(audio_signal=input, length=length)
encoder_output = enc_fun(audio_signal=audio_signal, length=length)
if isinstance(encoder_output, tuple):
encoder_output = encoder_output[0]
else:
encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun(
audio_signal=input,
audio_signal=audio_signal,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
Expand Down
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,8 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
"embs": NeuralType(('B', 'D'), AcousticEncodedRepresentation()),
}

def forward_for_export(self, processed_signal, processed_signal_len):
encoded, length = self.encoder(audio_signal=processed_signal, length=processed_signal_len)
def forward_for_export(self, audio_signal, length):
encoded, length = self.encoder(audio_signal=audio_signal, length=length)
logits, embs = self.decoder(encoder_output=encoded, length=length)
return logits, embs

Expand Down
70 changes: 37 additions & 33 deletions nemo/collections/asr/models/msdd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,7 @@ def add_speaker_model_config(self, cfg):
del cfg.speaker_model_cfg.validation_ds

def _init_segmentation_info(self):
"""Initialize segmentation settings: window, shift and multiscale weights.
"""
"""Initialize segmentation settings: window, shift and multiscale weights."""
self._diarizer_params = self.cfg_msdd_model.diarizer
self.multiscale_args_dict = parse_scale_configs(
self._diarizer_params.speaker_embeddings.parameters.window_length_in_sec,
Expand Down Expand Up @@ -275,10 +274,14 @@ def __setup_dataloader_from_config_infer(
)

def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]):
self._train_dl = self.__setup_dataloader_from_config(config=train_data_config,)
self._train_dl = self.__setup_dataloader_from_config(
config=train_data_config,
)

def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]):
self._validation_dl = self.__setup_dataloader_from_config(config=val_data_layer_config,)
self._validation_dl = self.__setup_dataloader_from_config(
config=val_data_layer_config,
)

def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
if self.pairwise_infer:
Expand Down Expand Up @@ -338,32 +341,32 @@ def get_ms_emb_seq(
Merged embeddings without zero-padding in the batch. See `ms_seg_counts` for details.
Shape: (Total number of segments in the batch, emb_dim)
scale_mapping (Tensor):
The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale
segment index which has the closest center distance with (n+1)-th segment in the base scale.
Example:
scale_mapping_argmat[2][101] = 85
In the above example, it means that 86-th segment in the 3rd scale (python index is 2) is mapped with
102-th segment in the base scale. Thus, the longer segments bound to have more repeating numbers since
multiple base scale segments (since the base scale has the shortest length) fall into the range of the
longer segments. At the same time, each row contains N numbers of indices where N is number of
segments in the base-scale (i.e., the finest scale).
The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale
segment index which has the closest center distance with (n+1)-th segment in the base scale.
Example:
scale_mapping_argmat[2][101] = 85
In the above example, it means that 86-th segment in the 3rd scale (python index is 2) is mapped with
102-th segment in the base scale. Thus, the longer segments bound to have more repeating numbers since
multiple base scale segments (since the base scale has the shortest length) fall into the range of the
longer segments. At the same time, each row contains N numbers of indices where N is number of
segments in the base-scale (i.e., the finest scale).
Shape: (batch_size, scale_n, self.diar_window_length)
ms_seg_counts (Tensor):
Cumulative sum of the number of segments in each scale. This information is needed to reconstruct
the multi-scale input matrix during forward propagating.

Example: `batch_size=3, scale_n=6, emb_dim=192`
ms_seg_counts =
[[8, 9, 12, 16, 25, 51],
[11, 13, 14, 17, 25, 51],
[ 9, 9, 11, 16, 23, 50]]
Example: `batch_size=3, scale_n=6, emb_dim=192`
ms_seg_counts =
[[8, 9, 12, 16, 25, 51],
[11, 13, 14, 17, 25, 51],
[ 9, 9, 11, 16, 23, 50]]

In this function, `ms_seg_counts` is used to get the actual length of each embedding sequence without
zero-padding.
In this function, `ms_seg_counts` is used to get the actual length of each embedding sequence without
zero-padding.

Returns:
ms_emb_seq (Tensor):
Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less repeated,
Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less repeated,
while shorter scales are more frequently repeated following the scale mapping tensor.
"""
scale_n, batch_size = scale_mapping[0].shape[0], scale_mapping.shape[0]
Expand Down Expand Up @@ -409,9 +412,9 @@ def get_cluster_avg_embs_model(
[ 9, 9, 11, 16, 23, 50]
]

Counts of merged segments: (121, 131, 118)
embs has shape of (370, 192)
clus_label_index has shape of (3, 131)
Counts of merged segments: (121, 131, 118)
embs has shape of (370, 192)
clus_label_index has shape of (3, 131)

Shape: (batch_size, scale_n)

Expand Down Expand Up @@ -553,7 +556,7 @@ def forward(
with torch.no_grad():
self.msdd._speaker_model.eval()
logits, embs_d = self.msdd._speaker_model.forward_for_export(
processed_signal=audio_signal[detach_ids[1]], processed_signal_len=audio_signal_len[detach_ids[1]]
audio_signal=audio_signal[detach_ids[1]], length=audio_signal_len[detach_ids[1]]
)
embs = torch.zeros(audio_signal.shape[0], embs_d.shape[1]).to(embs_d.device)
embs[detach_ids[1], :] = embs_d.detach()
Expand Down Expand Up @@ -854,9 +857,9 @@ def run_clustering_diarizer(self, manifest_filepath: str, emb_dir: str):
os.makedirs(self.out_rttm_dir, exist_ok=True)

self.clus_diar_model._cluster_params = self.cfg_diar_infer.diarizer.clustering.parameters
self.clus_diar_model.multiscale_args_dict[
"multiscale_weights"
] = self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.multiscale_weights
self.clus_diar_model.multiscale_args_dict["multiscale_weights"] = (
self.cfg_diar_infer.diarizer.speaker_embeddings.parameters.multiscale_weights
)
self.clus_diar_model._diarizer_params.speaker_embeddings.parameters = (
self.cfg_diar_infer.diarizer.speaker_embeddings.parameters
)
Expand Down Expand Up @@ -1076,7 +1079,6 @@ def extract_standalone_speaker_model(self, prefix: str = 'msdd._speaker_model.')
return _speaker_model

def _init_msdd_model(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig]):

"""
Initialized MSDD model with the provided config. Load either from `.nemo` file or `.ckpt` checkpoint files.
"""
Expand Down Expand Up @@ -1128,7 +1130,7 @@ def get_pred_mat(self, data_list: List[Union[Tuple[int], List[torch.Tensor]]]) -
digit_map = dict(zip(sorted(set(all_tups)), range(n_est_spks)))
total_len = max([sess[1].shape[1] for sess in data_list])
sum_pred = torch.zeros(total_len, n_est_spks)
for (_dim_tup, pred_mat) in data_list:
for _dim_tup, pred_mat in data_list:
dim_tup = [digit_map[x] for x in _dim_tup]
if len(pred_mat.shape) == 3:
pred_mat = pred_mat.squeeze(0)
Expand Down Expand Up @@ -1167,8 +1169,7 @@ def get_integrated_preds_list(
return output_list

def get_emb_clus_infer(self, cluster_embeddings):
"""Assign dictionaries containing the clustering results from the class instance `cluster_embeddings`.
"""
"""Assign dictionaries containing the clustering results from the class instance `cluster_embeddings`."""
self.msdd_model.emb_sess_test_dict = cluster_embeddings.emb_sess_test_dict
self.msdd_model.clus_test_label_dict = cluster_embeddings.clus_test_label_dict
self.msdd_model.emb_seq_test = cluster_embeddings.emb_seq_test
Expand Down Expand Up @@ -1456,7 +1457,10 @@ def from_pretrained(
"""
logging.setLevel(logging.INFO if verbose else logging.WARNING)
cfg = NeuralDiarizerInferenceConfig.init_config(
diar_model_path=model_name, vad_model_path=vad_model_name, map_location=map_location, verbose=verbose,
diar_model_path=model_name,
vad_model_path=vad_model_name,
map_location=map_location,
verbose=verbose,
)
return cls(cfg)

Expand Down
3 changes: 1 addition & 2 deletions nemo/collections/asr/modules/conformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def streaming_post_process(self, rets, keep_all_outputs=True):
def forward(
self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
):
self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device)
return self.forward_internal(
audio_signal,
length,
Expand All @@ -512,8 +513,6 @@ def forward(
def forward_internal(
self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None
):
self.update_max_seq_length(seq_length=audio_signal.size(2), device=audio_signal.device)

if length is None:
length = audio_signal.new_full(
(audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device
Expand Down
29 changes: 16 additions & 13 deletions nemo/collections/asr/parts/preprocessing/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def clean_spectrogram_batch(spectrogram: torch.Tensor, spectrogram_len: torch.Te


def splice_frames(x, frame_splicing):
""" Stacks frames together across feature dim
"""Stacks frames together across feature dim

input is batch_size, feature_dim, num_frames
output is batch_size, feature_dim*frame_splicing, num_frames
Expand Down Expand Up @@ -261,7 +261,7 @@ def __init__(
highfreq=None,
log=True,
log_zero_guard_type="add",
log_zero_guard_value=2 ** -24,
log_zero_guard_value=2**-24,
dither=CONSTANT,
pad_to=16,
max_duration=16.7,
Expand Down Expand Up @@ -308,6 +308,7 @@ def __init__(
self.hop_length = n_window_stride
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
self.stft_pad_amount = (self.n_fft - self.hop_length) // 2 if exact_pad else None
self.exact_pad = exact_pad

if exact_pad:
logging.info("STFT using exact pad")
Expand All @@ -321,15 +322,6 @@ def __init__(
window_fn = torch_windows.get(window, None)
window_tensor = window_fn(self.win_length, periodic=False) if window_fn else None
self.register_buffer("window", window_tensor)
self.stft = lambda x: torch.stft(
x,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
center=False if exact_pad else True,
window=self.window.to(dtype=torch.float),
return_complex=True,
)

self.normalize = normalize
self.log = log
Expand Down Expand Up @@ -388,6 +380,17 @@ def __init__(
logging.debug(f"using grads: {use_grads}")
logging.debug(f"nb_augmentation_prob: {nb_augmentation_prob}")

def stft(self, x):
return torch.stft(
x,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
center=False if self.exact_pad else True,
window=self.window.to(dtype=torch.float),
return_complex=True,
)

def log_zero_guard_value_fn(self, x):
if isinstance(self.log_zero_guard_value, str):
if self.log_zero_guard_value == "tiny":
Expand Down Expand Up @@ -508,7 +511,7 @@ def __init__(
highfreq: Optional[float] = None,
log: bool = True,
log_zero_guard_type: str = "add",
log_zero_guard_value: Union[float, str] = 2 ** -24,
log_zero_guard_value: Union[float, str] = 2**-24,
dither: float = 1e-5,
window: str = "hann",
pad_to: int = 0,
Expand Down Expand Up @@ -579,7 +582,7 @@ def __init__(

@property
def filter_banks(self):
""" Matches the analogous class """
"""Matches the analogous class"""
return self._mel_spec_extractor.mel_scale.fb

def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float:
Expand Down
6 changes: 3 additions & 3 deletions nemo/collections/asr/parts/submodules/jasper.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def forward_for_export(self, x, lengths):
mask = self.make_pad_mask(lengths, max_audio_length=max_len, device=x.device)
mask = ~mask # 0 represents value, 1 represents pad
x = x.float() # For stable AMP, SE must be computed at fp32.
x.masked_fill_(mask, 0.0) # mask padded values explicitly to 0
x = x.masked_fill(mask, 0.0) # mask padded values explicitly to 0
y = self._se_pool_step(x, mask) # [B, C, 1]
y = y.transpose(1, -1) # [B, 1, C]
y = self.fc(y) # [B, 1, C]
Expand Down Expand Up @@ -510,8 +510,8 @@ def _se_pool_step(self, x, mask):
return y

def set_max_len(self, max_len, seq_range=None):
""" Sets maximum input length.
Pre-calculates internal seq_range mask.
"""Sets maximum input length.
Pre-calculates internal seq_range mask.
"""
self.max_len = max_len
if seq_range is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,11 @@ def __getitem__(self, idx):


def build_train_valid_test_datasets(
cfg, retro_config: RetroConfig, train_valid_test_num_samples, seq_length, tokenizer,
cfg,
retro_config: RetroConfig,
train_valid_test_num_samples,
seq_length,
tokenizer,
):

# gpt dataset
Expand All @@ -135,7 +139,10 @@ def build_train_valid_test_datasets(
}

retro_train_ds, retro_valid_ds, retro_test_ds = get_retro_datasets(
config=retro_config, gpt_datasets=gpt_datasets, sample_length=seq_length, eod_token_id=tokenizer.eos_id,
config=retro_config,
gpt_datasets=gpt_datasets,
sample_length=seq_length,
eod_token_id=tokenizer.eos_id,
)

train_ds = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults

try:
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
Expand Down
22 changes: 13 additions & 9 deletions nemo/collections/tts/modules/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1, pre_lnorm=Fals
self.n_head = n_head
self.d_model = d_model
self.d_head = d_head
self.scale = 1 / (d_head ** 0.5)
self.scale = 1 / (d_head**0.5)
self.pre_lnorm = pre_lnorm

self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
Expand All @@ -125,13 +125,17 @@ def _forward(self, inp, attn_mask=None, conditioning=None):

head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)

head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)
s0 = inp.size(0)
s1 = inp.size(1)
s2 = s0 * n_head

q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
head_q = head_q.view(s0, s1, n_head, d_head)
head_k = head_k.view(s0, s1, n_head, d_head)
head_v = head_v.view(s0, s1, n_head, d_head)

q = head_q.permute(2, 0, 1, 3).reshape(s2, s1, d_head)
k = head_k.permute(2, 0, 1, 3).reshape(s2, s1, d_head)
v = head_v.permute(2, 0, 1, 3).reshape(s2, s1, d_head)

attn_score = torch.bmm(q, k.transpose(1, 2))
attn_score.mul_(self.scale)
Expand All @@ -145,8 +149,8 @@ def _forward(self, inp, attn_mask=None, conditioning=None):
attn_prob = self.dropatt(attn_prob)
attn_vec = torch.bmm(attn_prob, v)

attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), n_head * d_head)
attn_vec = attn_vec.view(n_head, s0, s1, d_head)
attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(s0, s1, n_head * d_head)

# linear projection
attn_out = self.o_net(attn_vec)
Expand Down
Loading
Loading