Merge branch 'main' into r1.10.0-megamolbart
michalivne committed Jul 25, 2022
2 parents 1e6a4d5 + c324499 commit 8c50ac8
Showing 1 changed file with 29 additions and 6 deletions: nemo/collections/asr/models/ssl_models.py
@@ -21,6 +21,7 @@
 from pytorch_lightning import Trainer
 
 from nemo.collections.asr.data import audio_to_text_dataset
+from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs
 from nemo.collections.asr.parts.mixins import ASRModuleMixin
 from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
 from nemo.core.classes import ModelPT
@@ -143,6 +144,18 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
         audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')
 
         shuffle = config['shuffle']
+        device = 'gpu' if torch.cuda.is_available() else 'cpu'
+        if config.get('use_dali', False):
+            device_id = self.local_rank if device == 'gpu' else None
+            dataset = audio_to_text_dataset.get_dali_char_dataset(
+                config=config,
+                shuffle=shuffle,
+                device_id=device_id,
+                global_rank=self.global_rank,
+                world_size=self.world_size,
+                preprocessor_cfg=self._cfg.preprocessor,
+            )
+            return dataset
 
         # Instantiate tarred dataset loader or normal dataset loader
         if config.get('is_tarred', False):
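Note on the hunk above: a dataset config that enables DALI now short-circuits the tarred/normal loader logic that follows and returns a DALI dataset directly. A minimal sketch of such a config — only 'use_dali', 'shuffle', and 'sample_rate' are keys visible in this diff; the remaining fields are hypothetical illustrations:

# Hypothetical example (not part of the commit): a train_ds config that would
# take the new DALI branch in _setup_dataloader_from_config.
train_ds_config = {
    "manifest_filepath": "train_manifest.json",  # illustrative path
    "sample_rate": 16000,                        # injected from the model config if absent
    "batch_size": 32,
    "shuffle": True,
    "use_dali": True,  # routes the call to audio_to_text_dataset.get_dali_char_dataset
}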
@@ -479,9 +492,14 @@ def decoder_loss_step(self, spectrograms, spec_masks, encoded, encoded_len, targ
     # PTL-specific methods
     def training_step(self, batch, batch_nb):
         signal, signal_len, targets, target_lengths = batch
-        spectrograms, spec_masks, encoded, encoded_len = self.forward(
-            input_signal=signal, input_signal_length=signal_len,
-        )
+        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
+            spectrograms, spec_masks, encoded, encoded_len = self.forward(
+                processed_signal=signal, processed_signal_length=signal_len,
+            )
+        else:
+            spectrograms, spec_masks, encoded, encoded_len = self.forward(
+                input_signal=signal, input_signal_length=signal_len,
+            )
 
         loss_value, loss_val_dict = self.decoder_loss_step(
             spectrograms, spec_masks, encoded, encoded_len, targets, target_lengths
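The same dispatch is mirrored in validation_step below. As a rough standalone sketch of the pattern this hunk introduces (forward_for_batch is a hypothetical helper, not NeMo API; the model is assumed to accept the forward keyword arguments shown above): a DALI batch that already carries GPU-preprocessed features is passed as processed_signal, bypassing the model's preprocessor, while a raw-waveform batch is passed as input_signal.

from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs

# Hypothetical helper (not part of the commit) illustrating the branch above.
def forward_for_batch(model, batch):
    signal, signal_len, _targets, _target_lengths = batch
    if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
        # DALI already produced the features on the GPU; skip the model's preprocessor.
        return model.forward(processed_signal=signal, processed_signal_length=signal_len)
    # Raw-audio path: the preprocessor inside forward() computes the features.
    return model.forward(input_signal=signal, input_signal_length=signal_len)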
@@ -508,9 +526,14 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
         self._in_validation_step = True
 
         signal, signal_len, targets, target_lengths = batch
-        spectrograms, spec_masks, encoded, encoded_len = self.forward(
-            input_signal=signal, input_signal_length=signal_len,
-        )
+        if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
+            spectrograms, spec_masks, encoded, encoded_len = self.forward(
+                processed_signal=signal, processed_signal_length=signal_len,
+            )
+        else:
+            spectrograms, spec_masks, encoded, encoded_len = self.forward(
+                input_signal=signal, input_signal_length=signal_len,
+            )
 
         loss_value, _ = self.decoder_loss_step(spectrograms, spec_masks, encoded, encoded_len, targets, target_lengths)
 
