
Commit

Updated docs.
Signed-off-by: Vahid <vnoroozi@nvidia.com>
VahidooX committed Oct 22, 2020
1 parent a4c4c79 commit 2c900aa
Showing 4 changed files with 85 additions and 77 deletions.
4 changes: 2 additions & 2 deletions nemo/collections/asr/modules/conformer_encoder.py
@@ -126,8 +126,8 @@ def __init__(
self.pre_encode = ConvSubsampling(
subsampling=subsampling,
subsampling_factor=subsampling_factor,
-             idim=feat_in,
-             odim=d_model,
+             feat_in=feat_in,
+             feat_out=d_model,
conv_channels=subsampling_conv_channels,
activation=nn.ReLU(),
)
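
For downstream callers, the only change here is the keyword rename (idim/odim become feat_in/feat_out). A minimal construction sketch follows; the hyperparameter values are illustrative, only the keyword names come from this commit.

    import torch.nn as nn
    from nemo.collections.asr.modules.subsampling import ConvSubsampling

    # Illustrative values; only the keyword names reflect this commit.
    pre_encode = ConvSubsampling(
        subsampling='striding',
        subsampling_factor=4,
        feat_in=80,        # e.g. number of mel-filterbank features
        feat_out=256,      # e.g. the encoder's d_model
        conv_channels=256,
        activation=nn.ReLU(),
    )
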
9 changes: 8 additions & 1 deletion nemo/collections/asr/modules/lstm_decoder.py
@@ -27,7 +27,14 @@

class LSTMDecoder(NeuralModule, Exportable):
"""
- Simple LSTM Decoder for use with CTC-based ASR models
+ Simple LSTM Decoder for ASR models
+ Args:
+     feat_in (int): size of the input features
+     num_classes (int): the size of the vocabulary
+     lstm_hidden_size (int): hidden size of the LSTM layers
+     vocabulary (vocab): the vocabulary
+     bidirectional (bool): whether the LSTM layers are bidirectional, default is False
+     num_layers (int): number of stacked LSTM layers, default is 1
"""

def save_to(self, save_path: str):
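
A minimal usage sketch based on the documented arguments above; the constructor signature is inferred from the docstring (this diff does not show it), and all values are illustrative.

    from nemo.collections.asr.modules.lstm_decoder import LSTMDecoder

    labels = [" ", "a", "b", "c"]      # toy vocabulary for illustration
    decoder = LSTMDecoder(             # signature assumed from the docstring above
        feat_in=256,                   # size of the encoder output features
        num_classes=len(labels),
        lstm_hidden_size=512,
        vocabulary=labels,
        bidirectional=False,
        num_layers=1,
    )
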
107 changes: 56 additions & 51 deletions nemo/collections/asr/modules/multi_head_attention.py
@@ -36,9 +36,10 @@

class MultiHeadAttention(nn.Module):
"""Multi-Head Attention layer.
- :param int n_head: the number of head s
- :param int n_feat: the number of features
- :param float dropout_rate: dropout rate
+ Args:
+     n_head (int): number of heads
+     n_feat (int): size of the features
+     dropout_rate (float): dropout rate
"""

def __init__(self, n_head, n_feat, dropout_rate):
@@ -56,29 +57,34 @@ def __init__(self, n_head, n_feat, dropout_rate):
self.dropout = nn.Dropout(p=dropout_rate)

def forward_qkv(self, query, key, value):
"""Transform query, key and value.
:param torch.Tensor query: (batch, time1, size)
:param torch.Tensor key: (batch, time2, size)
:param torch.Tensor value: (batch, time2, size)
:return torch.Tensor transformed query, key and value
"""Transforms query, key and value.
Args:
query (torch.Tensor): (batch, time1, size)
key (torch.Tensor): (batch, time2, size)
value (torch.Tensor): (batch, time2, size)
returns:
q (torch.Tensor): (batch, head, time1, size)
k (torch.Tensor): (batch, head, time2, size)
v (torch.Tensor): (batch, head, time2, size)
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
- q = q.transpose(1, 2)  # (batch, head, time1, d_k)
- k = k.transpose(1, 2)  # (batch, head, time2, d_k)
- v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)

return q, k, v

def forward_attention(self, value, scores, mask):
"""Compute attention context vector.
- :param torch.Tensor value: (batch, time2, size)
- :param torch.Tensor scores: (batch, time1, time2)
- :param torch.Tensor mask: (batch, time1, time2)
- :return torch.Tensor transformed `value` (batch, time2, d_model)
-     weighted by the attention score (batch, time1, time2)
+ Args:
+     value (torch.Tensor): (batch, time2, size)
+     scores (torch.Tensor): (batch, time1, time2)
+     mask (torch.Tensor): (batch, time1, time2)
+ Returns:
+     value (torch.Tensor): transformed `value` (batch, time2, d_model) weighted by the attention scores
"""
n_batch = value.size(0)
if mask is not None:
@@ -97,13 +103,13 @@ def forward_attention(self, value, scores, mask):

def forward(self, query, key, value, mask):
"""Compute 'Scaled Dot Product Attention'.
- :param torch.Tensor query: (batch, time1, size)
- :param torch.Tensor key: (batch, time2, size)
- :param torch.Tensor value: (batch, time2, size)
- :param torch.Tensor mask: (batch, time1, time2)
- :param torch.nn.Dropout dropout:
- :return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
-     weighted by the query dot key attention (batch, head, time1, time2)
+ Args:
+     query (torch.Tensor): (batch, time1, size)
+     key (torch.Tensor): (batch, time2, size)
+     value (torch.Tensor): (batch, time2, size)
+     mask (torch.Tensor): (batch, time1, time2)
+ Returns:
+     output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
"""
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
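
A rough self-attention sketch that follows the shapes documented above; batch, time, and feature sizes are made up, and passing mask=None simply skips masking.

    import torch
    from nemo.collections.asr.modules.multi_head_attention import MultiHeadAttention

    attn = MultiHeadAttention(n_head=4, n_feat=256, dropout_rate=0.1)
    x = torch.randn(2, 50, 256)        # (batch, time, n_feat)
    out = attn(x, x, x, mask=None)     # (batch, time, n_feat); a (batch, time1, time2) mask can be passed instead of None
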
@@ -113,9 +119,10 @@
class RelPositionMultiHeadAttention(MultiHeadAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
- :param int n_head: the number of head s
- :param int n_feat: the number of features
- :param float dropout_rate: dropout rate
+ Args:
+     n_head (int): number of heads
+     n_feat (int): size of the features
+     dropout_rate (float): dropout rate
"""

def __init__(self, n_head, n_feat, dropout_rate):
@@ -132,8 +139,9 @@ def __init__(self, n_head, n_feat, dropout_rate):

def rel_shift(self, x, zero_triu=False):
"""Compute relative positinal encoding.
- :param torch.Tensor x: (batch, time, size)
- :param bool zero_triu: return the lower triangular part of the matrix
+ Args:
+     x (torch.Tensor): (batch, time, size)
+     zero_triu (bool): return the lower triangular part of the matrix
"""
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)
@@ -149,14 +157,14 @@ def rel_shift(self, x, zero_triu=False):

def forward(self, query, key, value, mask, pos_emb):
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
- :param torch.Tensor query: (batch, time1, size)
- :param torch.Tensor key: (batch, time2, size)
- :param torch.Tensor value: (batch, time2, size)
- :param torch.Tensor pos_emb: (batch, time1, size)
- :param torch.Tensor mask: (batch, time1, time2)
- :param torch.nn.Dropout dropout:
- :return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
-     weighted by the query dot key attention (batch, head, time1, time2)
+ Args:
+     query (torch.Tensor): (batch, time1, size)
+     key (torch.Tensor): (batch, time2, size)
+     value (torch.Tensor): (batch, time2, size)
+     mask (torch.Tensor): (batch, time1, time2)
+     pos_emb (torch.Tensor): (batch, time1, size)
+ Returns:
+     output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
"""
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
@@ -188,10 +196,11 @@ def forward(self, query, key, value, mask, pos_emb):

class PositionalEncoding(torch.nn.Module):
"""Positional encoding.
- :param int d_model: embedding dim
- :param float dropout_rate: dropout rate
- :param int max_len: maximum input length
- :param reverse: whether to reverse the input position
+ Args:
+     d_model (int): embedding dim
+     dropout_rate (float): dropout rate
+     max_len (int): maximum input length
+     reverse (bool): whether to reverse the input position
"""

def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False, xscale=None):
@@ -229,7 +238,7 @@ def forward(self, x: torch.Tensor):
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
Returns:
- torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+ Encoded Output (torch.Tensor): Its shape is (batch, time, ...)
"""
self.extend_pe(x)
if self.xscale:
@@ -241,17 +250,13 @@ def forward(self, x: torch.Tensor):
class RelPositionalEncoding(PositionalEncoding):
"""Relitive positional encoding module.
See : Appendix B in https://arxiv.org/abs/1901.02860
- :param int d_model: embedding dim
- :param float dropout_rate: dropout rate
- :param int max_len: maximum input length
+ Args:
+     d_model (int): embedding dim
+     dropout_rate (float): dropout rate
+     max_len (int): maximum input length
"""

def __init__(self, d_model, dropout_rate, max_len=5000, dropout_emb_rate=0.0, xscale=None):
"""Initialize class.
:param int d_model: embedding dim
:param float dropout_rate: dropout rate
:param int max_len: maximum input length
"""
super().__init__(d_model, dropout_rate, max_len, reverse=True, xscale=xscale)

if dropout_emb_rate > 0:
@@ -264,8 +269,8 @@ def forward(self, x):
Args:
x (torch.Tensor): Input. Its shape is (batch, time, ...)
Returns:
- torch.Tensor: x. Its shape is (batch, time, ...)
- torch.Tensor: pos_emb. Its shape is (1, time, ...)
+ x (torch.Tensor): Its shape is (batch, time, ...)
+ pos_emb (torch.Tensor): Its shape is (1, time, ...)
"""
self.extend_pe(x)
if self.xscale:
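
A sketch of how the relative positional encoding pairs with the relative-position attention documented above; the hyperparameters are illustrative, and the actual wiring lives in the Conformer encoder, not in this diff.

    import torch
    from nemo.collections.asr.modules.multi_head_attention import (
        RelPositionalEncoding,
        RelPositionMultiHeadAttention,
    )

    d_model = 256
    pos_enc = RelPositionalEncoding(d_model=d_model, dropout_rate=0.1)
    rel_attn = RelPositionMultiHeadAttention(n_head=4, n_feat=d_model, dropout_rate=0.1)

    x = torch.randn(2, 50, d_model)                      # (batch, time, d_model)
    x, pos_emb = pos_enc(x)                              # returns x and pos_emb, as documented above
    out = rel_attn(x, x, x, mask=None, pos_emb=pos_emb)  # (batch, time, d_model)
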
42 changes: 19 additions & 23 deletions nemo/collections/asr/modules/subsampling.py
@@ -19,19 +19,21 @@


class ConvSubsampling(torch.nn.Module):
"""Convolutional 2D subsampling (to 1/4 length).
:param int idim: input dim
:param int odim: output dim
:param str activation: activation functions
:param flaot dropout_rate: dropout rate
"""Convolutional subsampling which supports VGGNet and striding approach introduced in:
VGGNet Subsampling: https://arxiv.org/pdf/1910.12977.pdf
Striding Subsampling:
"Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong et al.
Args:
subsampling (str): The subsampling technique from {"vggnet", "striding"}
subsampling_factor (int): The subsampling factor which should be a power of 2
feat_in (int): size of the input features
feat_out (int): size of the output features
conv_channels (int): Number of channels for the convolution layers.
activation (Module): activation function, default is nn.ReLU()
"""

- def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1, activation=nn.ReLU()):
+ def __init__(self, subsampling, subsampling_factor, feat_in, feat_out, conv_channels, activation=nn.ReLU()):
super(ConvSubsampling, self).__init__()
- if conv_channels <= 0:
-     conv_channels = odim
self._subsampling = subsampling

if subsampling_factor % 2 != 0:
@@ -44,7 +46,7 @@ def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1
self._padding = 0
self._stride = 2
self._kernel_size = 2
- self._ceil_mode = True  # TODO: is False better?
+ self._ceil_mode = True

for i in range(self._sampling_num):
layers.append(
@@ -69,7 +71,7 @@ def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1
)
in_channels = conv_channels
elif subsampling == 'striding':
- self._padding = 1  # TODO: is 0 better?
+ self._padding = 1
self._stride = 2
self._kernel_size = 3
self._ceil_mode = False
@@ -89,7 +91,7 @@ def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1
else:
raise ValueError(f"Not valid sub-sampling: {subsampling}!")

- in_length = idim
+ in_length = feat_in
for i in range(self._sampling_num):
out_length = calc_length(
length=int(in_length),
@@ -100,23 +102,16 @@ def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1
)
in_length = out_length

- self.out = torch.nn.Linear(conv_channels * out_length, odim)
+ self.out = torch.nn.Linear(conv_channels * out_length, feat_out)
self.conv = torch.nn.Sequential(*layers)

def forward(self, x, lengths):
"""Subsample x.
:param torch.Tensor x: input tensor
:param torch.Tensor x_mask: input mask
:return: subsampled x and mask
:rtype Tuple[torch.Tensor or Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
"""
x = x.unsqueeze(1) # (b, c, t, f)
x = x.unsqueeze(1)
x = self.conv(x)
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))

- # TODO: improve the performance of here
+ # TODO: improve the performance of length calculation
new_lengths = lengths
for i in range(self._sampling_num):
new_lengths = [
@@ -135,6 +130,7 @@ def forward(self, x, lengths):


def calc_length(length, padding, kernel_size, stride, ceil_mode):
""" Calculates the output length of a Tensor passed through a convolution or max pooling layer"""
if ceil_mode:
length = math.ceil((length + (2 * padding) - (kernel_size - 1) - 1) / float(stride) + 1)
else:
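
To make the new calc_length docstring concrete, a small worked example with the 'striding' settings shown above (padding=1, kernel_size=3, stride=2, ceil_mode=False); the input length is illustrative.

    from nemo.collections.asr.modules.subsampling import calc_length

    length = 80                        # illustrative input length
    for _ in range(2):                 # subsampling_factor=4 -> two striding layers
        length = calc_length(length, padding=1, kernel_size=3, stride=2, ceil_mode=False)
    print(length)                      # 80 -> 40 -> 20
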
