Commit 75aa1e4
fixed docs.
Signed-off-by: Vahid <vnoroozi@nvidia.com>
VahidooX committed Oct 22, 2020
1 parent ac86cb3 commit 75aa1e4
Showing 2 changed files with 10 additions and 19 deletions.
28 changes: 10 additions & 18 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -39,7 +39,8 @@ class ConformerEncoder(NeuralModule, Exportable):
     """
     The encoder for ASR model of Conformer.
     Based on this paper:
-    https://arxiv.org/abs/2005.08100
+    'Conformer: Convolution-augmented Transformer for Speech Recognition' by Anmol Gulati et al.
+    https://arxiv.org/abs/2005.08100
     """

     def _prepare_for_export(self):
@@ -98,7 +99,7 @@ def __init__(
         feat_in,
         n_layers,
         d_model,
-        feat_out=0,
+        feat_out=-1,
         subsampling='vggnet',
         subsampling_factor=4,
         subsampling_conv_channels=64,
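
The default feat_out changes from 0 to -1 here. Below is a hypothetical usage sketch: the argument values are illustrative, only the constructor parameters visible in this hunk are used, and feat_out=-1 is assumed to mean "keep the encoder output at d_model rather than adding an extra output projection".

# Hypothetical usage sketch -- values are illustrative, not from the commit.
import torch
from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder

encoder = ConformerEncoder(
    feat_in=80,                  # e.g. 80-dim log-mel features
    n_layers=16,
    d_model=256,
    feat_out=-1,                 # new default: assumed to keep the output at d_model
    subsampling='vggnet',
    subsampling_factor=4,
    subsampling_conv_channels=64,
)

audio_signal = torch.randn(4, 80, 1200)        # (B, feat_in, T)
length = torch.tensor([1200, 1000, 800, 600])  # valid frames per utterance
# forward(audio_signal, length) -- signature as shown in the next hunk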
@@ -206,13 +207,7 @@ def forward(self, audio_signal, length):

     @staticmethod
     def make_pad_mask(seq_lens, max_time, device=None):
-        """Make masking for padding.
-        Args:
-            seq_lens (IntTensor): `[B]`
-            device_id (int):
-        Returns:
-            mask (IntTensor): `[B, T]`
-        """
+        """Make masking for padding."""
         bs = seq_lens.size(0)
         seq_range = torch.arange(0, max_time, dtype=torch.int32)
         seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_time)
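
The hunk cuts off before the mask is actually produced. A self-contained sketch of the full construction, with the final comparison step assumed from the standard broadcast pattern (the device argument is omitted):

import torch

def make_pad_mask(seq_lens, max_time):
    """Sketch: (B, T) bool mask, True at valid (non-padded) time steps."""
    bs = seq_lens.size(0)
    seq_range = torch.arange(0, max_time, dtype=torch.int32)        # (T,)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_time)  # (B, T)
    # Assumed final step: compare each position against its sequence length.
    return seq_range_expand < seq_lens.unsqueeze(-1)

print(make_pad_mask(torch.tensor([3, 5]), max_time=5))
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])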
@@ -268,18 +263,15 @@ def __init__(self, d_model, d_ff, conv_kernel_size, self_attention_model, n_head
         self.dropout = nn.Dropout(dropout)
         self.norm_out = LayerNorm(d_model)

-    def forward(self, x, att_mask=None, pos_emb=None, u_bias=None, v_bias=None, pad_mask=None):
+    def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None):
         """
         Args:
-            x (FloatTensor): `[B, T, d_model]`
-            att_mask (ByteTensor): `[B, T, T]`
-            pos_emb (LongTensor): `[L, 1, d_model]`
-            u (FloatTensor): global parameter for relative positional embedding
-            v (FloatTensor): global parameter for relative positional embedding
+            x (torch.Tensor): input signals (B, T, d_model)
+            att_mask (torch.Tensor): attention masks(B, T, T)
+            pos_emb (torch.Tensor): (L, 1, d_model)
+            pad_mask (torch.tensor): padding mask
         Returns:
-            xs (FloatTensor): `[B, T, d_model]`
-            xx_aws (FloatTensor): `[B, H, T, T]`
+            x (torch.Tensor): (B, T, d_model)
         """
         residual = x
         x = self.norm_feed_forward1(x)
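
The hunk truncates just after the first feed-forward pre-norm. For orientation, here is a schematic, self-contained sketch of the macaron-style forward pass described in the Conformer paper; apart from norm_feed_forward1, dropout, and norm_out (visible in the diff), the submodules are plain PyTorch stand-ins, not NeMo's actual implementations.

import torch
import torch.nn as nn

class ConformerLayerSketch(nn.Module):
    """Macaron-style Conformer layer: 1/2 FF -> self-attention -> conv -> 1/2 FF."""

    def __init__(self, d_model=256, d_ff=1024, n_heads=4, dropout=0.1):
        super().__init__()
        self.norm_feed_forward1 = nn.LayerNorm(d_model)
        self.feed_forward1 = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model))  # SiLU == Swish
        self.norm_self_att = nn.LayerNorm(d_model)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm_conv = nn.LayerNorm(d_model)
        # Stand-in for NeMo's conv module: a depthwise 1-D convolution.
        self.conv = nn.Conv1d(d_model, d_model, kernel_size=31, padding=15, groups=d_model)
        self.norm_feed_forward2 = nn.LayerNorm(d_model)
        self.feed_forward2 = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model))
        self.dropout = nn.Dropout(dropout)
        self.norm_out = nn.LayerNorm(d_model)

    def forward(self, x):  # x: (B, T, d_model); attention/padding masks elided
        residual = x
        x = self.norm_feed_forward1(x)  # the hunk above ends here
        x = residual + 0.5 * self.dropout(self.feed_forward1(x))  # half-step FF

        residual = x
        x = self.norm_self_att(x)
        x, _ = self.self_attn(x, x, x)
        x = residual + self.dropout(x)

        residual = x
        x = self.norm_conv(x)
        x = residual + self.dropout(self.conv(x.transpose(1, 2)).transpose(1, 2))

        residual = x
        x = self.norm_feed_forward2(x)
        x = residual + 0.5 * self.dropout(self.feed_forward2(x))  # half-step FF

        return self.norm_out(x)  # (B, T, d_model)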
1 change: 0 additions & 1 deletion nemo/collections/asr/modules/conformer_modules.py
@@ -70,7 +70,6 @@ class ConformerFeedForward(nn.Module):
     """
     feed-forward module of Conformer model.
     """
-
     def __init__(self, d_model, d_ff, dropout, activation=Swish()):
         super(ConformerFeedForward, self).__init__()
         self.linear1 = nn.Linear(d_model, d_ff)
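
This hunk stops at linear1. A self-contained sketch of the complete block, with the layers after linear1 assumed from the Conformer paper's feed-forward design (linear -> Swish -> dropout -> linear):

import torch
import torch.nn as nn

class Swish(nn.Module):
    """x * sigmoid(x), the activation used by Conformer's feed-forward module."""

    def forward(self, x):
        return x * torch.sigmoid(x)

class ConformerFeedForward(nn.Module):
    """feed-forward module of Conformer model."""

    def __init__(self, d_model, d_ff, dropout, activation=Swish()):
        super(ConformerFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)  # where the hunk above ends
        self.activation = activation             # layers below are assumed
        self.dropout = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(self.dropout(self.activation(self.linear1(x))))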
