Skip to content

Commit

Permalink
add padding based on encoder output
Browse files Browse the repository at this point in the history
Signed-off-by: Nithin Rao Koluguri <nithinraok>
  • Loading branch information
Nithin Rao Koluguri committed Apr 3, 2024
1 parent e772dbf commit 3374027
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions nemo/collections/asr/losses/ssl_losses/contrastive.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from math import ceil

import torch
import torch.nn.functional as F
from torch import nn
Expand Down Expand Up @@ -147,13 +149,17 @@ def sample_negatives(self, y, num):

@typecheck()
def forward(self, spectrograms, spec_masks, decoder_outputs, decoder_lengths=None):
spec_in = spectrograms.transpose(-2, -1)
targets = spectrograms.transpose(-2, -1)
masks = spec_masks.transpose(-2, -1)
targets = spec_in
# BxTxC
diff = int(ceil(targets.shape[1] / decoder_outputs.shape[1]) * decoder_outputs.shape[1]) - targets.shape[1]

if diff > 0:
targets = F.pad(targets, (0, 0, 0, diff))
masks = F.pad(masks, (0, 0, 0, diff))

targets = targets.reshape(targets.shape[0], targets.shape[1] // self.combine_time_steps, -1)
masks = masks.reshape(targets.shape[0], targets.shape[1], -1)
targets = targets.reshape(targets.shape[0], decoder_outputs.shape[1], -1)
masks = masks.reshape(targets.shape[0], decoder_outputs.shape[1], -1)

if self.quantized_targets:
if self.store_ids:
Expand Down

0 comments on commit 3374027

Please sign in to comment.