From 2c900aa2210681fd54480024469dcdc0af1ea802 Mon Sep 17 00:00:00 2001
From: Vahid
Date: Wed, 21 Oct 2020 17:21:34 -0700
Subject: [PATCH] Updated docs.

Signed-off-by: Vahid
---
 .../asr/modules/conformer_encoder.py          |   4 +-
 nemo/collections/asr/modules/lstm_decoder.py  |   9 +-
 .../asr/modules/multi_head_attention.py       | 107 +++++++++---------
 nemo/collections/asr/modules/subsampling.py   |  42 ++++---
 4 files changed, 85 insertions(+), 77 deletions(-)

diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py
index 779394a5ba05..41416b06ee1b 100644
--- a/nemo/collections/asr/modules/conformer_encoder.py
+++ b/nemo/collections/asr/modules/conformer_encoder.py
@@ -126,8 +126,8 @@ def __init__(
         self.pre_encode = ConvSubsampling(
             subsampling=subsampling,
             subsampling_factor=subsampling_factor,
-            idim=feat_in,
-            odim=d_model,
+            feat_in=feat_in,
+            feat_out=d_model,
             conv_channels=subsampling_conv_channels,
             activation=nn.ReLU(),
         )
diff --git a/nemo/collections/asr/modules/lstm_decoder.py b/nemo/collections/asr/modules/lstm_decoder.py
index 149d9b1fb9fa..1a00c9077be0 100644
--- a/nemo/collections/asr/modules/lstm_decoder.py
+++ b/nemo/collections/asr/modules/lstm_decoder.py
@@ -27,7 +27,14 @@

 class LSTMDecoder(NeuralModule, Exportable):
     """
-    Simple LSTM Decoder for use with CTC-based ASR models
+    Simple LSTM Decoder for ASR models
+    Args:
+        feat_in (int): size of the input features
+        num_classes (int): the size of the vocabulary
+        lstm_hidden_size (int): hidden size of the LSTM layers
+        vocabulary (list): the vocabulary of labels for the model
+        bidirectional (bool): whether the LSTM layers are bidirectional. Defaults to False.
+        num_layers (int): number of stacked LSTM layers. Defaults to 1.
     """

     def save_to(self, save_path: str):
diff --git a/nemo/collections/asr/modules/multi_head_attention.py b/nemo/collections/asr/modules/multi_head_attention.py
index f2c26a1e39ad..5a9501d963a5 100644
--- a/nemo/collections/asr/modules/multi_head_attention.py
+++ b/nemo/collections/asr/modules/multi_head_attention.py
@@ -36,9 +36,10 @@

 class MultiHeadAttention(nn.Module):
     """Multi-Head Attention layer.
-    :param int n_head: the number of head s
-    :param int n_feat: the number of features
-    :param float dropout_rate: dropout rate
+    Args:
+        n_head (int): number of heads
+        n_feat (int): size of the features
+        dropout_rate (float): dropout rate
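+
+    Example (illustrative only; the sizes below are arbitrary):
+        >>> import torch
+        >>> attn = MultiHeadAttention(n_head=4, n_feat=256, dropout_rate=0.1)
+        >>> query = torch.randn(8, 50, 256)  # (batch, time1, n_feat)
+        >>> key = torch.randn(8, 60, 256)  # (batch, time2, n_feat)
+        >>> value = torch.randn(8, 60, 256)  # (batch, time2, n_feat)
+        >>> mask = torch.ones(8, 50, 60, dtype=torch.bool)  # all-ones mask, so no position is masked
+        >>> out = attn(query, key, value, mask)  # (batch, time1, n_feat)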
     """

     def __init__(self, n_head, n_feat, dropout_rate):
@@ -56,29 +57,34 @@ def __init__(self, n_head, n_feat, dropout_rate):
         self.dropout = nn.Dropout(p=dropout_rate)

     def forward_qkv(self, query, key, value):
-        """Transform query, key and value.
-        :param torch.Tensor query: (batch, time1, size)
-        :param torch.Tensor key: (batch, time2, size)
-        :param torch.Tensor value: (batch, time2, size)
-        :return torch.Tensor transformed query, key and value
+        """Transforms query, key and value.
+        Args:
+            query (torch.Tensor): (batch, time1, size)
+            key (torch.Tensor): (batch, time2, size)
+            value (torch.Tensor): (batch, time2, size)
+        Returns:
+            q (torch.Tensor): (batch, head, time1, d_k)
+            k (torch.Tensor): (batch, head, time2, d_k)
+            v (torch.Tensor): (batch, head, time2, d_k)
         """
         n_batch = query.size(0)
         q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
         k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
         v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
-        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
-        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
-        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)

         return q, k, v

     def forward_attention(self, value, scores, mask):
         """Compute attention context vector.
-        :param torch.Tensor value: (batch, time2, size)
-        :param torch.Tensor scores: (batch, time1, time2)
-        :param torch.Tensor mask: (batch, time1, time2)
-        :return torch.Tensor transformed `value` (batch, time2, d_model)
-            weighted by the attention score (batch, time1, time2)
+        Args:
+            value (torch.Tensor): (batch, time2, size)
+            scores (torch.Tensor): (batch, time1, time2)
+            mask (torch.Tensor): (batch, time1, time2)
+        Returns:
+            value (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the attention scores
         """
         n_batch = value.size(0)
         if mask is not None:
@@ -97,13 +103,13 @@ def forward(self, query, key, value, mask):
         """Compute 'Scaled Dot Product Attention'.
-        :param torch.Tensor query: (batch, time1, size)
-        :param torch.Tensor key: (batch, time2, size)
-        :param torch.Tensor value: (batch, time2, size)
-        :param torch.Tensor mask: (batch, time1, time2)
-        :param torch.nn.Dropout dropout:
-        :return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
-            weighted by the query dot key attention (batch, head, time1, time2)
+        Args:
+            query (torch.Tensor): (batch, time1, size)
+            key (torch.Tensor): (batch, time2, size)
+            value (torch.Tensor): (batch, time2, size)
+            mask (torch.Tensor): (batch, time1, time2)
+        Returns:
+            output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
         """
         q, k, v = self.forward_qkv(query, key, value)
         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
@@ -113,9 +119,10 @@
 class RelPositionMultiHeadAttention(MultiHeadAttention):
     """Multi-Head Attention layer with relative position encoding.
     Paper: https://arxiv.org/abs/1901.02860
-    :param int n_head: the number of head s
-    :param int n_feat: the number of features
-    :param float dropout_rate: dropout rate
+    Args:
+        n_head (int): number of heads
+        n_feat (int): size of the features
+        dropout_rate (float): dropout rate
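+
+    Example (illustrative only; pos_emb is normally produced by the RelPositionalEncoding module defined below):
+        >>> import torch
+        >>> attn = RelPositionMultiHeadAttention(n_head=4, n_feat=256, dropout_rate=0.1)
+        >>> x = torch.randn(8, 50, 256)  # (batch, time, n_feat)
+        >>> mask = torch.ones(8, 50, 50, dtype=torch.bool)
+        >>> pos_emb = torch.randn(1, 50, 256)  # stand-in for the output of RelPositionalEncoding
+        >>> out = attn(x, x, x, mask, pos_emb)  # self-attention, (batch, time, n_feat)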
     """

     def __init__(self, n_head, n_feat, dropout_rate):
@@ -132,8 +139,9 @@ def __init__(self, n_head, n_feat, dropout_rate):

     def rel_shift(self, x, zero_triu=False):
         """Compute relative positional encoding.
-        :param torch.Tensor x: (batch, time, size)
-        :param bool zero_triu: return the lower triangular part of the matrix
+        Args:
+            x (torch.Tensor): (batch, time, size)
+            zero_triu (bool): return the lower triangular part of the matrix
         """
         zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
         x_padded = torch.cat([zero_pad, x], dim=-1)
@@ -149,14 +157,14 @@ def forward(self, query, key, value, mask, pos_emb):
         """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
-        :param torch.Tensor query: (batch, time1, size)
-        :param torch.Tensor key: (batch, time2, size)
-        :param torch.Tensor value: (batch, time2, size)
-        :param torch.Tensor pos_emb: (batch, time1, size)
-        :param torch.Tensor mask: (batch, time1, time2)
-        :param torch.nn.Dropout dropout:
-        :return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
-            weighted by the query dot key attention (batch, head, time1, time2)
+        Args:
+            query (torch.Tensor): (batch, time1, size)
+            key (torch.Tensor): (batch, time2, size)
+            value (torch.Tensor): (batch, time2, size)
+            mask (torch.Tensor): (batch, time1, time2)
+            pos_emb (torch.Tensor): (batch, time1, size)
+        Returns:
+            output (torch.Tensor): transformed `value` (batch, time1, d_model) weighted by the query dot key attention
         """
         q, k, v = self.forward_qkv(query, key, value)
         q = q.transpose(1, 2)  # (batch, time1, head, d_k)
@@ -188,10 +196,11 @@

 class PositionalEncoding(torch.nn.Module):
     """Positional encoding.
-    :param int d_model: embedding dim
-    :param float dropout_rate: dropout rate
-    :param int max_len: maximum input length
-    :param reverse: whether to reverse the input position
+    Args:
+        d_model (int): embedding dim
+        dropout_rate (float): dropout rate
+        max_len (int): maximum input length
+        reverse (bool): whether to reverse the input position
     """

     def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False, xscale=None):
@@ -229,7 +238,7 @@ def forward(self, x: torch.Tensor):
         Args:
             x (torch.Tensor): Input. Its shape is (batch, time, ...)
         Returns:
-            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+            output (torch.Tensor): encoded tensor. Its shape is (batch, time, ...)
         """
         self.extend_pe(x)
         if self.xscale:
@@ -241,17 +250,13 @@

 class RelPositionalEncoding(PositionalEncoding):
     """Relative positional encoding module.
     See: Appendix B in https://arxiv.org/abs/1901.02860
-    :param int d_model: embedding dim
-    :param float dropout_rate: dropout rate
-    :param int max_len: maximum input length
+    Args:
+        d_model (int): embedding dim
+        dropout_rate (float): dropout rate
+        max_len (int): maximum input length
     """

     def __init__(self, d_model, dropout_rate, max_len=5000, dropout_emb_rate=0.0, xscale=None):
-        """Initialize class.
-        :param int d_model: embedding dim
-        :param float dropout_rate: dropout rate
-        :param int max_len: maximum input length
-        """
         super().__init__(d_model, dropout_rate, max_len, reverse=True, xscale=xscale)

         if dropout_emb_rate > 0:
@@ -264,8 +269,8 @@ def forward(self, x):
         Args:
             x (torch.Tensor): Input. Its shape is (batch, time, ...)
         Returns:
-            torch.Tensor: x. Its shape is (batch, time, ...)
-            torch.Tensor: pos_emb. Its shape is (1, time, ...)
+            x (torch.Tensor): Its shape is (batch, time, ...)
+            pos_emb (torch.Tensor): Its shape is (1, time, ...)
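+
+        Example (illustrative only; sizes are arbitrary, and the default max_len of 5000 is assumed):
+            >>> import torch
+            >>> pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
+            >>> x = torch.randn(8, 50, 256)  # (batch, time, d_model)
+            >>> x, pos_emb = pos_enc(x)  # x: (8, 50, 256), pos_emb: (1, 50, 256)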
""" self.extend_pe(x) if self.xscale: diff --git a/nemo/collections/asr/modules/subsampling.py b/nemo/collections/asr/modules/subsampling.py index 2d1d5e07c320..85f31611b8b4 100644 --- a/nemo/collections/asr/modules/subsampling.py +++ b/nemo/collections/asr/modules/subsampling.py @@ -19,19 +19,21 @@ class ConvSubsampling(torch.nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - :param int idim: input dim - :param int odim: output dim - :param str activation: activation functions - :param flaot dropout_rate: dropout rate - + """Convolutional subsampling which supports VGGNet and striding approach introduced in: + VGGNet Subsampling: https://arxiv.org/pdf/1910.12977.pdf + Striding Subsampling: + "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong et al. + Args: + subsampling (str): The subsampling technique from {"vggnet", "striding"} + subsampling_factor (int): The subsampling factor which should be a power of 2 + feat_in (int): size of the input features + feat_out (int): size of the output features + conv_channels (int): Number of channels for the convolution layers. + activation (Module): activation function, default is nn.ReLU() """ - def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1, activation=nn.ReLU()): + def __init__(self, subsampling, subsampling_factor, feat_in, feat_out, conv_channels, activation=nn.ReLU()): super(ConvSubsampling, self).__init__() - if conv_channels <= 0: - conv_channels = odim self._subsampling = subsampling if subsampling_factor % 2 != 0: @@ -44,7 +46,7 @@ def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1 self._padding = 0 self._stride = 2 self._kernel_size = 2 - self._ceil_mode = True # TODO: is False better? + self._ceil_mode = True for i in range(self._sampling_num): layers.append( @@ -69,7 +71,7 @@ def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1 ) in_channels = conv_channels elif subsampling == 'striding': - self._padding = 1 # TODO: is 0 better? + self._padding = 1 self._stride = 2 self._kernel_size = 3 self._ceil_mode = False @@ -89,7 +91,7 @@ def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1 else: raise ValueError(f"Not valid sub-sampling: {subsampling}!") - in_length = idim + in_length = feat_in for i in range(self._sampling_num): out_length = calc_length( length=int(in_length), @@ -100,23 +102,16 @@ def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1 ) in_length = out_length - self.out = torch.nn.Linear(conv_channels * out_length, odim) + self.out = torch.nn.Linear(conv_channels * out_length, feat_out) self.conv = torch.nn.Sequential(*layers) def forward(self, x, lengths): - """Subsample x. 
     """

-    def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1, activation=nn.ReLU()):
+    def __init__(self, subsampling, subsampling_factor, feat_in, feat_out, conv_channels, activation=nn.ReLU()):
         super(ConvSubsampling, self).__init__()
-        if conv_channels <= 0:
-            conv_channels = odim
         self._subsampling = subsampling

         if subsampling_factor % 2 != 0:
@@ -44,7 +46,7 @@ def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1
             self._padding = 0
             self._stride = 2
             self._kernel_size = 2
-            self._ceil_mode = True  # TODO: is False better?
+            self._ceil_mode = True

             for i in range(self._sampling_num):
                 layers.append(
@@ -69,7 +71,7 @@ def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1
                 )
                 in_channels = conv_channels
         elif subsampling == 'striding':
-            self._padding = 1  # TODO: is 0 better?
+            self._padding = 1
             self._stride = 2
             self._kernel_size = 3
             self._ceil_mode = False
@@ -89,7 +91,7 @@ def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1
         else:
             raise ValueError(f"Not valid sub-sampling: {subsampling}!")

-        in_length = idim
+        in_length = feat_in
         for i in range(self._sampling_num):
             out_length = calc_length(
@@ -100,23 +102,16 @@ def __init__(self, subsampling, subsampling_factor, idim, odim, conv_channels=-1
             )
             in_length = out_length

-        self.out = torch.nn.Linear(conv_channels * out_length, odim)
+        self.out = torch.nn.Linear(conv_channels * out_length, feat_out)
         self.conv = torch.nn.Sequential(*layers)

     def forward(self, x, lengths):
-        """Subsample x.
-
-        :param torch.Tensor x: input tensor
-        :param torch.Tensor x_mask: input mask
-        :return: subsampled x and mask
-        :rtype Tuple[torch.Tensor or Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
-        """
-        x = x.unsqueeze(1)  # (b, c, t, f)
+        x = x.unsqueeze(1)
         x = self.conv(x)
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))

-        # TODO: improve the performance of here
+        # TODO: improve the performance of length calculation
         new_lengths = lengths
         for i in range(self._sampling_num):
             new_lengths = [
@@ -135,6 +130,7 @@


 def calc_length(length, padding, kernel_size, stride, ceil_mode):
+    """Calculates the output length of a tensor passed through a convolution or max pooling layer."""
     if ceil_mode:
         length = math.ceil((length + (2 * padding) - (kernel_size - 1) - 1) / float(stride) + 1)
     else: