Skip to content

Commit

Permalink
added stacking support to conformer. (#4045)
Browse files Browse the repository at this point in the history
  • Loading branch information
VahidooX committed Apr 27, 2022
1 parent da1b56c commit 0d052c8
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 18 deletions.
34 changes: 20 additions & 14 deletions nemo/collections/asr/modules/conformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from nemo.collections.asr.parts.submodules.conformer_modules import ConformerLayer
from nemo.collections.asr.parts.submodules.multi_head_attention import PositionalEncoding, RelPositionalEncoding
from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling
from nemo.collections.asr.parts.submodules.subsampling import ConvSubsampling, StackingSubsampling
from nemo.core.classes.common import typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.module import NeuralModule
Expand Down Expand Up @@ -149,18 +149,23 @@ def __init__(
if subsampling_conv_channels == -1:
subsampling_conv_channels = d_model
if subsampling and subsampling_factor > 1:
self.pre_encode = ConvSubsampling(
subsampling=subsampling,
subsampling_factor=subsampling_factor,
feat_in=feat_in,
feat_out=d_model,
conv_channels=subsampling_conv_channels,
activation=nn.ReLU(),
)
self._feat_out = d_model
if subsampling == 'stacking':
self.pre_encode = StackingSubsampling(
subsampling_factor=subsampling_factor, feat_in=feat_in, feat_out=d_model
)
else:
self.pre_encode = ConvSubsampling(
subsampling=subsampling,
subsampling_factor=subsampling_factor,
feat_in=feat_in,
feat_out=d_model,
conv_channels=subsampling_conv_channels,
activation=nn.ReLU(),
)
else:
self.pre_encode = nn.Linear(feat_in, d_model)
self._feat_out = d_model

self._feat_out = d_model

if not untie_biases and self_attention_model == "rel_pos":
d_head = d_model // n_heads
Expand Down Expand Up @@ -247,10 +252,11 @@ def forward_for_export(self, audio_signal, length):

audio_signal = torch.transpose(audio_signal, 1, 2)

if isinstance(self.pre_encode, ConvSubsampling):
audio_signal, length = self.pre_encode(audio_signal, length)
else:
if isinstance(self.pre_encode, nn.Linear):
audio_signal = self.pre_encode(audio_signal)
else:
audio_signal, length = self.pre_encode(audio_signal, length)

audio_signal, pos_emb = self.pos_enc(audio_signal)
# adjust size
max_audio_length = audio_signal.size(1)
Expand Down
6 changes: 3 additions & 3 deletions nemo/collections/asr/modules/rnn_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,10 @@ def forward(self, audio_signal, length=None):

audio_signal = torch.transpose(audio_signal, 1, 2)

if isinstance(self.pre_encode, ConvSubsampling) or isinstance(self.pre_encode, StackingSubsampling):
audio_signal, length = self.pre_encode(audio_signal, length)
else:
if isinstance(self.pre_encode, nn.Linear):
audio_signal = self.pre_encode(audio_signal)
else:
audio_signal, length = self.pre_encode(audio_signal, length)

for lth, layer in enumerate(self.layers):
audio_signal = layer(audio_signal)
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/asr/parts/submodules/subsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, subsampling_factor, feat_in, feat_out):

def forward(self, x, lengths):
b, t, h = x.size()
pad_size = self.subsampling_factor - (t % self.subsampling_factor)
pad_size = (self.subsampling_factor - (t % self.subsampling_factor)) % self.subsampling_factor
x = torch.nn.functional.pad(x, (0, 0, 0, pad_size))
_, t, _ = x.size()
x = torch.reshape(x, (b, t // self.subsampling_factor, h * self.subsampling_factor))
Expand Down

0 comments on commit 0d052c8

Please sign in to comment.