Skip to content

Commit

Permalink
Adding the original change made for label_models (NVIDIA#9377) (NVIDIA#9378)
Browse files Browse the repository at this point in the history

Signed-off-by: Taejin Park <tango4j@gmail.com>
Co-authored-by: Taejin Park <tango4j@gmail.com>
Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>
  • Loading branch information
2 people authored and BoxiangW committed Jun 5, 2024
1 parent 0cdf29e commit cfa0043
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions nemo/collections/asr/models/label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if 'loss' in cfg:
cfg_eval_loss = copy.deepcopy(cfg.loss)

if 'angular' in cfg.loss._target_:
if '_target_' in cfg.loss and 'angular' in cfg.loss._target_:
OmegaConf.set_struct(cfg, True)
with open_dict(cfg):
cfg.decoder.angular = True
Expand Down Expand Up @@ -341,7 +341,8 @@ def forward_for_export(self, processed_signal, processed_signal_len):
@typecheck()
def forward(self, input_signal, input_signal_length):
processed_signal, processed_signal_len = self.preprocessor(
input_signal=input_signal, length=input_signal_length,
input_signal=input_signal,
length=input_signal_length,
)

if self.spec_augmentation is not None and self.training:
Expand Down Expand Up @@ -627,7 +628,9 @@ def batch_inference(self, manifest_filepath, batch_size=32, sample_rate=16000, d
dataset = AudioToSpeechLabelDataset(manifest_filepath=manifest_filepath, labels=None, featurizer=featurizer)

dataloader = torch.utils.data.DataLoader(
dataset=dataset, batch_size=batch_size, collate_fn=dataset.fixed_seq_collate_fn,
dataset=dataset,
batch_size=batch_size,
collate_fn=dataset.fixed_seq_collate_fn,
)

logits = []
Expand Down

0 comments on commit cfa0043

Please sign in to comment.