-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
conformer_encoder.py
310 lines (263 loc) · 10.6 KB
/
conformer_encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn.modules import LayerNorm
from nemo.collections.asr.modules.conformer_modules import ConformerConvolution, ConformerFeedForward
from nemo.collections.asr.modules.multi_head_attention import (
MultiHeadAttention,
PositionalEncoding,
RelPositionalEncoding,
RelPositionMultiHeadAttention,
)
from nemo.collections.asr.modules.subsampling import ConvSubsampling
from nemo.core.classes.common import typecheck
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.module import NeuralModule
from nemo.core.neural_types import AcousticEncodedRepresentation, LengthsType, NeuralType, SpectrogramType
# Public API of this module; ConformerEncoderBlock is an internal building block.
__all__ = ['ConformerEncoder']
class ConformerEncoder(NeuralModule, Exportable):
    """
    The encoder for ASR model of Conformer.
    Based on this paper:
    https://arxiv.org/abs/2005.08100

    Args:
        feat_in (int): number of input features per frame (the ``D`` axis of ``audio_signal``).
        n_layers (int): number of ConformerEncoderBlock layers to stack.
        d_model (int): hidden dimension of the encoder.
        feat_out (int): if > 0 and different from ``d_model``, a final nn.Linear projects the
            encoder output to ``feat_out`` features. Defaults to 0 (no projection).
        subsampling (str): subsampling type forwarded to ConvSubsampling. A falsy value replaces
            the subsampling front-end with a single ``nn.Linear(feat_in, d_model)``.
        subsampling_factor (int): forwarded to ConvSubsampling.
        subsampling_conv_channels (int): forwarded to ConvSubsampling as ``conv_channels``.
        ff_expansion_factor (int): feed-forward hidden size is ``d_model * ff_expansion_factor``.
        self_attention_model (str): 'rel_pos' (relative positional encoding and attention) or
            'abs_pos' (absolute positional encoding and standard multi-head attention).
        n_heads (int): number of attention heads.
        xscaling (bool): if True, ``sqrt(d_model)`` is passed as ``xscale`` to the positional encoding.
        conv_kernel_size (int): kernel size of the convolution module in every block.
        dropout (float): dropout rate for the blocks and the positional encoding.
        dropout_emb (float): embedding dropout rate, only used by the 'rel_pos' positional encoding.
        dropout_att (float): dropout rate inside the attention modules.

    Raises:
        ValueError: if ``self_attention_model`` is neither 'rel_pos' nor 'abs_pos'.
    """

    def _prepare_for_export(self):
        Exportable._prepare_for_export(self)

    def input_example(self):
        """
        Generates input examples for tracing etc.
        Returns:
            A tuple of input examples.
        """
        # Dummy (B=16, D=feat_in, T=256) input on the same device as the module's parameters.
        input_example = torch.randn(16, self._feat_in, 256).to(next(self.parameters()).device)
        return tuple([input_example])

    @property
    def disabled_deployment_input_names(self):
        """Implement this method to return a set of input names disabled for export"""
        return {"length"}

    @property
    def disabled_deployment_output_names(self):
        """Implement this method to return a set of output names disabled for export"""
        return {"encoded_lengths"}

    def save_to(self, save_path: str):
        # Intentionally a no-op for this module.
        pass

    @classmethod
    def restore_from(cls, restore_path: str):
        # Intentionally a no-op for this module.
        pass

    @property
    def input_types(self):
        """Returns definitions of module input ports.
        """
        return OrderedDict(
            {
                "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
                "length": NeuralType(tuple('B'), LengthsType()),
            }
        )

    @property
    def output_types(self):
        """Returns definitions of module output ports.
        """
        return OrderedDict(
            {
                "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
                "encoded_lengths": NeuralType(tuple('B'), LengthsType()),
            }
        )

    def __init__(
        self,
        feat_in,
        n_layers,
        d_model,
        feat_out=0,
        subsampling='vggnet',
        subsampling_factor=4,
        subsampling_conv_channels=64,
        ff_expansion_factor=4,
        self_attention_model='rel_pos',
        n_heads=4,
        xscaling=True,
        conv_kernel_size=31,
        dropout=0.1,
        dropout_emb=0.1,
        dropout_att=0.0,
    ):
        super().__init__()

        d_ff = d_model * ff_expansion_factor
        self.d_model = d_model
        # Kept so input_example() can build a correctly shaped dummy input.
        # (The previous `self.__feat_in` lookup was name-mangled and never assigned.)
        self._feat_in = feat_in
        self.scale = math.sqrt(self.d_model)
        self.xscale = math.sqrt(d_model) if xscaling else None

        # Front-end: convolutional subsampling (changes T and the lengths) or a plain
        # per-frame linear projection (keeps T unchanged).
        if subsampling:
            self.pre_encode = ConvSubsampling(
                subsampling=subsampling,
                subsampling_factor=subsampling_factor,
                feat_in=feat_in,
                feat_out=d_model,
                conv_channels=subsampling_conv_channels,
                activation=nn.ReLU(),
            )
        else:
            self.pre_encode = nn.Linear(feat_in, d_model)
        self.feat_out = d_model

        self.u_bias = None
        self.v_bias = None

        if self_attention_model == "rel_pos":
            self.pos_enc = RelPositionalEncoding(
                d_model=d_model, dropout_rate=dropout, dropout_emb_rate=dropout_emb, xscale=self.xscale
            )
        elif self_attention_model == "abs_pos":
            self.pos_enc = PositionalEncoding(
                d_model=d_model, dropout_rate=dropout, max_len=6000, reverse=False, xscale=self.xscale
            )
        else:
            raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!")

        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            layer = ConformerEncoderBlock(
                d_model=d_model,
                d_ff=d_ff,
                conv_kernel_size=conv_kernel_size,
                self_attention_model=self_attention_model,
                n_heads=n_heads,
                dropout=dropout,
                dropout_att=dropout_att,
            )
            self.layers.append(layer)

        # Optional output projection. The encoder's native width is self.feat_out (== d_model);
        # previously this compared against an undefined `self.output_dim`.
        if feat_out > 0 and feat_out != self.feat_out:
            self.out_proj = nn.Linear(self.feat_out, feat_out)
            self.feat_out = feat_out
        else:
            self.out_proj = None
            self.feat_out = d_model

    @typecheck()
    def forward(self, audio_signal, length):
        """Encode a batch of spectrograms.

        Args:
            audio_signal: `[B, D, T]` input features.
            length: `[B]` number of valid frames per example.

        Returns:
            A tuple ``(outputs, encoded_lengths)`` where ``outputs`` is `[B, feat_out, T']` and
            ``encoded_lengths`` are the frame counts after the front-end.
        """
        # (B, D, T) -> (B, T, D)
        audio_signal = torch.transpose(audio_signal, 1, 2)

        if isinstance(self.pre_encode, ConvSubsampling):
            audio_signal, length = self.pre_encode(audio_signal, length)
        else:
            # Linear front-end: per-frame projection, lengths are unchanged.
            audio_signal = self.pre_encode(audio_signal)

        audio_signal, pos_emb = self.pos_enc(audio_signal)
        max_time = audio_signal.size(1)

        # valid_mask[b, t] is True for frames t < length[b].
        valid_mask = self.make_pad_mask(length, max_time=max_time, device=audio_signal.device)
        # att_mask[b, i, j] is True when both frame i and frame j are valid.
        # NOTE(review): confirm this True-means-valid polarity matches the convention of the
        # attention modules' `mask` argument (defined outside this file).
        att_mask = valid_mask.unsqueeze(1).repeat([1, max_time, 1])
        att_mask = att_mask & att_mask.transpose(1, 2)
        # pad_mask is True at padded frames, shape (B, T, 1).
        pad_mask = (~valid_mask).unsqueeze(2)

        for layer in self.layers:
            audio_signal = layer(
                x=audio_signal,
                att_mask=att_mask,
                pos_emb=pos_emb,
                u_bias=self.u_bias,
                v_bias=self.v_bias,
                pad_mask=pad_mask,
            )

        if self.out_proj is not None:
            audio_signal = self.out_proj(audio_signal)

        # (B, T, D) -> (B, D, T)
        audio_signal = torch.transpose(audio_signal, 1, 2)
        return audio_signal, length

    @staticmethod
    def make_pad_mask(seq_lens, max_time, device=None):
        """Make a boolean mask of the valid (non-padded) time steps.

        Args:
            seq_lens (Tensor): `[B]` sequence lengths (cast to int32, i.e. truncated).
            max_time (int): size T of the time axis of the mask.
            device (torch.device | int | str, optional): if not None, the mask is moved here.

        Returns:
            mask (BoolTensor): `[B, T]`, True where the time index is < the sequence length.
        """
        # Build on the lengths' own device to avoid an implicit round trip through the default device.
        time_idx = torch.arange(0, max_time, dtype=torch.int32, device=seq_lens.device)
        lens = seq_lens.to(dtype=torch.int32)
        # [1, T] < [B, 1] broadcasts to [B, T].
        mask = time_idx.unsqueeze(0) < lens.unsqueeze(-1)
        # `is not None` so that falsy-but-valid devices such as 0 (cuda:0) are honoured.
        if device is not None:
            mask = mask.to(device)
        return mask
class ConformerEncoderBlock(torch.nn.Module):
    """A single block of the Conformer encoder.

    Each sub-module is applied with a pre-LayerNorm and a residual connection, in order:
    feed-forward (output scaled by 0.5), multi-head self-attention, convolution module,
    feed-forward (output scaled by 0.5). A final LayerNorm is applied to the result.

    Args:
        d_model (int): input dimension of MultiheadAttentionMechanism and PositionwiseFeedForward
        d_ff (int): hidden dimension of PositionwiseFeedForward
        conv_kernel_size (int): kernel size for depthwise convolution in convolution module
        self_attention_model (str): 'rel_pos' for RelPositionMultiHeadAttention,
            'abs_pos' for MultiHeadAttention
        n_heads (int): number of heads for multi-head attention
        dropout (float): dropout probabilities for linear layers
        dropout_att (float): dropout probabilities for attention distributions

    Raises:
        ValueError: if ``self_attention_model`` is neither 'rel_pos' nor 'abs_pos'.
    """

    def __init__(self, d_model, d_ff, conv_kernel_size, self_attention_model, n_heads, dropout, dropout_att):
        super().__init__()

        self.self_attention_model = self_attention_model
        self.n_heads = n_heads
        # Scale applied to each feed-forward module's output before its residual add.
        self.fc_factor = 0.5

        # first feed forward module
        self.norm_feed_forward1 = LayerNorm(d_model)
        self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)

        # convolution module
        self.norm_conv = LayerNorm(d_model)
        self.conv = ConformerConvolution(d_model=d_model, kernel_size=conv_kernel_size)

        # multi-headed self-attention module
        self.norm_self_att = LayerNorm(d_model)
        if self_attention_model == 'rel_pos':
            self.self_attn = RelPositionMultiHeadAttention(n_head=n_heads, n_feat=d_model, dropout_rate=dropout_att)
        elif self_attention_model == 'abs_pos':
            self.self_attn = MultiHeadAttention(n_head=n_heads, n_feat=d_model, dropout_rate=dropout_att)
        else:
            raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!")

        # second feed forward module
        self.norm_feed_forward2 = LayerNorm(d_model)
        self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)

        self.dropout = nn.Dropout(dropout)
        self.norm_out = LayerNorm(d_model)

    def forward(self, x, att_mask=None, pos_emb=None, u_bias=None, v_bias=None, pad_mask=None):
        """
        Args:
            x (FloatTensor): `[B, T, d_model]`
            att_mask (Tensor): `[B, T, T]` mask forwarded unchanged to the self-attention module
            pos_emb (Tensor): positional embedding, only used by 'rel_pos' attention
                (shape presumably `[L, 1, d_model]` -- TODO confirm against RelPositionalEncoding)
            u_bias: accepted for interface compatibility; not used by this block
            v_bias: accepted for interface compatibility; not used by this block
            pad_mask: accepted for interface compatibility; not used by this block
        Returns:
            x (FloatTensor): `[B, T, d_model]`

        Raises:
            ValueError: if ``self.self_attention_model`` was changed to an unsupported value
                after construction.
        """
        residual = x
        x = self.norm_feed_forward1(x)
        x = self.feed_forward1(x)
        x = self.fc_factor * self.dropout(x) + residual

        residual = x
        x = self.norm_self_att(x)
        if self.self_attention_model == 'rel_pos':
            x = self.self_attn(query=x, key=x, value=x, pos_emb=pos_emb, mask=att_mask)
        elif self.self_attention_model == 'abs_pos':
            x = self.self_attn(query=x, key=x, value=x, mask=att_mask)
        else:
            # Unreachable unless the attribute was mutated after __init__; fail with a clear
            # message instead of an opaque error from dropout(None).
            raise ValueError(f"Not valid self_attention_model: '{self.self_attention_model}'!")
        x = self.dropout(x) + residual

        residual = x
        x = self.norm_conv(x)
        x = self.conv(x)
        x = self.dropout(x) + residual

        residual = x
        x = self.norm_feed_forward2(x)
        x = self.feed_forward2(x)
        x = self.fc_factor * self.dropout(x) + residual

        x = self.norm_out(x)
        return x