
Merge branch 'master' of https://github.com/NVIDIA/NeMo into mfcc_features

Signed-off-by: Jocelyn Huang <jocelynh@nvidia.com>
redoctopus committed Nov 22, 2019
2 parents a905a54 + 597154e commit 2d1d7b2
Showing 11 changed files with 392 additions and 104 deletions.
2 changes: 1 addition & 1 deletion collections/nemo_asr/nemo_asr/audio_preprocessing.py
@@ -27,7 +27,7 @@
import torchaudio
try:
from apex import amp
except AttributeError:
except (AttributeError, ModuleNotFoundError) as e:
print("Unable to import APEX. Mixed precision and distributed training "
"will not work.")

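The widened except clause above is needed because a machine without apex installed raises ModuleNotFoundError rather than AttributeError on the import. A minimal standalone sketch of the same guard pattern follows; the AMP_AVAILABLE flag and the maybe_initialize_amp helper are illustrative additions, not part of NeMo.

try:
    from apex import amp  # ModuleNotFoundError if apex is absent entirely
    AMP_AVAILABLE = True
except (AttributeError, ModuleNotFoundError):
    amp = None
    AMP_AVAILABLE = False
    print("Unable to import APEX. Mixed precision and distributed training "
          "will not work.")


def maybe_initialize_amp(model, optimizer, opt_level="O1"):
    """Wrap model/optimizer with apex AMP when it is installed, else pass through."""
    if AMP_AVAILABLE:
        return amp.initialize(model, optimizer, opt_level=opt_level)
    return model, optimizer
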
2 changes: 0 additions & 2 deletions collections/nemo_asr/nemo_asr/jasper.py
@@ -116,7 +116,6 @@ def __init__(
self.dense_residual = True
groups = lcfg.get('groups', 1)
separable = lcfg.get('separable', False)
tied = lcfg.get('tied', False)
heads = lcfg.get('heads', -1)
encoder_layers.append(
JasperBlock(feat_in,
@@ -133,7 +132,6 @@ def __init__(
residual_mode=residual_mode,
normalization=normalization_mode,
norm_groups=norm_groups,
tied=tied,
activation=activation,
residual_panes=dense_res,
conv_mask=conv_mask))
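With the tied-weights option gone, a Jasper encoder layer config is read purely through lcfg.get(...) with defaults, so a stale 'tied' key in an older config dict is simply never consulted. A hedged sketch of that parsing, with illustrative values:

# Illustrative layer config; 'tied' is no longer read anywhere.
lcfg = {
    'filters': 256,
    'repeat': 5,
    'kernel': [11],
    'stride': [1],
    'dilation': [1],
    'dropout': 0.2,
    'residual': True,
    'separable': True,
}

groups = lcfg.get('groups', 1)            # falls back to 1 when the key is absent
separable = lcfg.get('separable', False)
heads = lcfg.get('heads', -1)
# lcfg.get('tied', False) used to be read here; JasperBlock no longer accepts it.
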
194 changes: 103 additions & 91 deletions collections/nemo_asr/nemo_asr/parts/jasper.py
@@ -1,7 +1,22 @@
# Taken straight from Patter https://github.com/ryanleary/patter
# TODO: review, and copyright and fix/add comments
# Copyright (C) NVIDIA CORPORATION. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import torch
import torch.nn as nn
from torch import Tensor

jasper_activations = {
"hardtanh": nn.Hardtanh,
@@ -11,7 +26,9 @@


def init_weights(m, mode='xavier_uniform'):
if isinstance(m, nn.Conv1d) or isinstance(m, MaskedConv1d):
if isinstance(m, MaskedConv1d):
init_weights(m.conv, mode)
if isinstance(m, nn.Conv1d):
if mode == 'xavier_uniform':
nn.init.xavier_uniform_(m.weight, gain=1.0)
elif mode == 'xavier_normal':
@@ -40,50 +57,52 @@ def get_same_padding(kernel_size, stride, dilation):
return kernel_size // 2


class MaskedConv1d(nn.Conv1d):
class MaskedConv1d(nn.Module):
__constants__ = ["use_conv_mask", "real_out_channels", "heads"]

def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, heads=-1, bias=False,
use_mask=True):
super(MaskedConv1d, self).__init__()

if not (heads == -1 or groups == in_channels):
raise ValueError("Only use heads for depthwise convolutions")

self.real_out_channels = out_channels
if heads != -1:
self.real_out_channels = out_channels
in_channels = heads
out_channels = heads
groups = heads

super(MaskedConv1d, self).__init__(in_channels, out_channels,
kernel_size,
stride=stride,
padding=padding, dilation=dilation,
groups=groups, bias=bias)
self.conv = nn.Conv1d(in_channels, out_channels,
kernel_size,
stride=stride,
padding=padding, dilation=dilation,
groups=groups, bias=bias)
self.use_mask = use_mask
self.heads = heads

def get_seq_len(self, lens):
return ((lens + 2 * self.padding[0] - self.dilation[0] * (
self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
return ((lens + 2 * self.conv.padding[0] - self.conv.dilation[0] * (
self.conv.kernel_size[0] - 1) - 1) / self.conv.stride[0] + 1)

def forward(self, x, lens):
if self.use_mask:
lens = lens.to(dtype=torch.long)
max_len = x.size(2)
mask = torch.arange(max_len).to(lens.device)\
mask = torch.arange(max_len).to(lens.device) \
.expand(len(lens), max_len) >= lens.unsqueeze(1)
x = x.masked_fill(
mask.unsqueeze(1).type(torch.bool).to(device=x.device), 0
mask.unsqueeze(1).to(device=x.device), 0
)
del mask
# del mask
lens = self.get_seq_len(lens)

sh = x.shape
if self.heads != -1:
sh = x.shape
x = x.view(-1, self.heads, sh[-1])

out, lens = super(MaskedConv1d, self).forward(x), lens
out = self.conv(x)

if self.heads != -1:
out = out.view(sh[0], self.real_out_channels, -1)
@@ -112,11 +131,12 @@ def forward(self, x):


class JasperBlock(nn.Module):
__constants__ = ["conv_mask", "separable", "residual_mode", "res", "mconv"]

def __init__(self, inplanes, planes, repeat=3, kernel_size=11, stride=1,
dilation=1, padding='same', dropout=0.2, activation=None,
residual=True, groups=1, separable=False,
heads=-1, tied=False, normalization="batch",
heads=-1, normalization="batch",
norm_groups=1, residual_mode='add',
residual_panes=[], conv_mask=False):
super(JasperBlock, self).__init__()
@@ -129,11 +149,11 @@ def __init__(self, inplanes, planes, repeat=3, kernel_size=11, stride=1,
self.separable = separable
self.residual_mode = residual_mode

self.conv = nn.ModuleList()
inplanes_loop = inplanes
conv = nn.ModuleList()

if tied:
rep_layer = self._get_conv_bn_layer(
for _ in range(repeat - 1):
conv.extend(self._get_conv_bn_layer(
inplanes_loop,
planes,
kernel_size=kernel_size,
@@ -144,73 +164,70 @@ def __init__(self, inplanes, planes, repeat=3, kernel_size=11, stride=1,
heads=heads,
separable=separable,
normalization=normalization,
norm_groups=norm_groups)
norm_groups=norm_groups))

for _ in range(repeat - 1):
if tied:
self.conv.extend(rep_layer)
else:
self.conv.extend(
self._get_conv_bn_layer(
inplanes_loop,
planes,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding_val,
groups=groups,
heads=heads,
separable=separable,
normalization=normalization,
norm_groups=norm_groups))

self.conv.extend(
self._get_act_dropout_layer(
drop_prob=dropout,
activation=activation))
conv.extend(self._get_act_dropout_layer(
drop_prob=dropout,
activation=activation))

inplanes_loop = planes

if tied:
self.conv.extend(rep_layer)
else:
self.conv.extend(
self._get_conv_bn_layer(
inplanes_loop,
planes,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding_val,
groups=groups,
heads=heads,
separable=separable,
normalization=normalization,
norm_groups=norm_groups))
conv.extend(self._get_conv_bn_layer(
inplanes_loop,
planes,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=padding_val,
groups=groups,
heads=heads,
separable=separable,
normalization=normalization,
norm_groups=norm_groups))

self.mconv = conv

self.res = nn.ModuleList() if residual else None
res_panes = residual_panes.copy()
self.dense_residual = residual

if residual:
res_list = nn.ModuleList()
if len(residual_panes) == 0:
res_panes = [inplanes]
self.dense_residual = False
for ip in res_panes:
self.res.append(
nn.ModuleList(
modules=self._get_conv_bn_layer(
ip,
planes,
kernel_size=1,
normalization=normalization,
norm_groups=norm_groups)))
self.out = nn.Sequential(
res_list.append(nn.ModuleList(self._get_conv_bn_layer(
ip,
planes,
kernel_size=1,
normalization=normalization,
norm_groups=norm_groups)))
self.res = res_list
else:
self.res = None

self.mout = nn.Sequential(
*self._get_act_dropout_layer(
drop_prob=dropout,
activation=activation
)
activation=activation)
)

def _get_conv(self, in_channels, out_channels, kernel_size=11,
stride=1, dilation=1, padding=0, bias=False,
groups=1, heads=-1, separable=False):
use_mask = self.conv_mask
if use_mask:
return MaskedConv1d(in_channels, out_channels, kernel_size,
stride=stride,
dilation=dilation, padding=padding, bias=bias,
groups=groups, heads=heads,
use_mask=use_mask)
else:
return nn.Conv1d(in_channels, out_channels, kernel_size,
stride=stride,
dilation=dilation, padding=padding, bias=bias,
groups=groups)

def _get_conv_bn_layer(self, in_channels, out_channels, kernel_size=11,
stride=1, dilation=1, padding=0, bias=False,
groups=1, heads=-1, separable=False,
@@ -220,23 +237,20 @@ def _get_conv_bn_layer(self, in_channels, out_channels, kernel_size=11,

if separable:
layers = [
MaskedConv1d(in_channels, in_channels, kernel_size,
stride=stride,
dilation=dilation, padding=padding, bias=bias,
groups=in_channels, heads=heads,
use_mask=self.conv_mask),
MaskedConv1d(in_channels, out_channels, kernel_size=1,
stride=1,
dilation=1, padding=0, bias=bias, groups=groups,
use_mask=self.conv_mask)
self._get_conv(in_channels, in_channels, kernel_size,
stride=stride,
dilation=dilation, padding=padding, bias=bias,
groups=in_channels, heads=heads),
self._get_conv(in_channels, out_channels, kernel_size,
stride=1,
dilation=1, padding=0, bias=bias, groups=groups)
]
else:
layers = [
MaskedConv1d(in_channels, out_channels, kernel_size,
stride=stride,
dilation=dilation, padding=padding, bias=bias,
groups=groups,
use_mask=self.conv_mask)
self._get_conv(in_channels, out_channels, kernel_size,
stride=stride,
dilation=dilation, padding=padding, bias=bias,
groups=groups)
]

if normalization == "group":
@@ -268,15 +282,13 @@ def _get_act_dropout_layer(self, drop_prob=0.2, activation=None):
]
return layers

def forward(self, input_):

def forward(self, input_: Tuple[Tensor, Tensor]):
xs, lens_orig = input_

# compute forward convolutions
out = xs[-1]

lens = lens_orig
for i, l in enumerate(self.conv):
for i, l in enumerate(self.mconv):
# if we're doing masked convolutions, we need to pass in and
# possibly update the sequence lengths
# if (i % 4) == 0 and self.conv_mask:
Expand All @@ -301,7 +313,7 @@ def forward(self, input_):
out = torch.max(out, res_out)

# compute the output
out = self.out(out)
out = self.mout(out)
if self.res is not None and self.dense_residual:
return xs + [out], lens

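MaskedConv1d now wraps an nn.Conv1d as self.conv instead of subclassing it, which is why init_weights recurses into m.conv and why padding, dilation, and stride are reached through self.conv in get_seq_len (the new __constants__ lists suggest TorchScript compatibility is the motivation). A rough usage sketch follows; the import path, shapes, and the returned (output, lengths) pair are assumptions based on this diff, not guaranteed by it.

import torch
from nemo_asr.parts.jasper import MaskedConv1d, get_same_padding  # assumed import path

# Illustrative shapes: batch of 4, 64 input channels, 200 time steps.
conv = MaskedConv1d(in_channels=64, out_channels=128, kernel_size=11,
                    padding=get_same_padding(11, stride=1, dilation=1),
                    use_mask=True)
x = torch.randn(4, 64, 200)
lens = torch.tensor([200, 160, 90, 40])   # valid frames per example

out, out_lens = conv(x, lens)   # frames beyond each length are zeroed before the conv
print(out.shape, out_lens)      # expected: torch.Size([4, 128, 200]) plus updated lengths
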
2 changes: 2 additions & 0 deletions collections/nemo_nlp/nemo_nlp/huggingface/bert.py
@@ -114,6 +114,8 @@ def __init__(self, *,

self.add_module("bert", model)
self.config = model.config
for key, value in self.config.to_dict().items():
self._local_parameters[key] = value

@staticmethod
def list_pretrained_models() -> Optional[List[PretrainedModelInfo]]:
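The bert.py change copies every field of the wrapped model's HuggingFace config into the module's local parameters, presumably so they can be inspected alongside the module's other construction settings. A standalone sketch of the same copy, using transformers.BertConfig purely for illustration (the NeMo code iterates over model.config instead):

from transformers import BertConfig

config = BertConfig()          # defaults include hidden_size=768, num_attention_heads=12
local_parameters = {}
for key, value in config.to_dict().items():
    local_parameters[key] = value

print(local_parameters["hidden_size"])          # 768
print(local_parameters["num_attention_heads"])  # 12
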
