-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
conformer_modules.py
170 lines (140 loc) · 5.98 KB
/
conformer_modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from torch import nn as nn
from torch.nn import LayerNorm
from nemo.collections.asr.parts.activations import Swish
from nemo.collections.asr.parts.multi_head_attention import MultiHeadAttention, RelPositionMultiHeadAttention
__all__ = ['ConformerConvolution', 'ConformerFeedForward', 'ConformerEncoderBlock']
class ConformerEncoderBlock(torch.nn.Module):
    """A single block of the Conformer encoder.

    The block applies four residual sub-modules in sequence, each preceded by
    its own LayerNorm: feed-forward -> self-attention -> convolution ->
    feed-forward, followed by a final LayerNorm. The output of each
    feed-forward module is scaled by ``fc_factor`` (0.5) before it is added
    back to the residual.

    Args:
        d_model (int): input dimension of MultiheadAttentionMechanism and PositionwiseFeedForward
        d_ff (int): hidden dimension of PositionwiseFeedForward
        conv_kernel_size (int): kernel size for depthwise convolution in convolution module
        self_attention_model (str): type of the self-attention layer, either
            'rel_pos' (RelPositionMultiHeadAttention) or 'abs_pos' (MultiHeadAttention)
        n_heads (int): number of heads for multi-head attention
        dropout (float): dropout probabilities for linear layers
        dropout_att (float): dropout probabilities for attention distributions

    Raises:
        ValueError: if ``self_attention_model`` is neither 'rel_pos' nor 'abs_pos'.
    """

    def __init__(self, d_model, d_ff, conv_kernel_size, self_attention_model, n_heads, dropout, dropout_att):
        super(ConformerEncoderBlock, self).__init__()

        self.self_attention_model = self_attention_model
        self.n_heads = n_heads
        # scaling applied to both feed-forward outputs before the residual add
        self.fc_factor = 0.5

        # first feed forward module
        self.norm_feed_forward1 = LayerNorm(d_model)
        self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)

        # convolution module
        self.norm_conv = LayerNorm(d_model)
        self.conv = ConformerConvolution(d_model=d_model, kernel_size=conv_kernel_size)

        # multi-headed self-attention module
        self.norm_self_att = LayerNorm(d_model)
        if self_attention_model == 'rel_pos':
            self.self_attn = RelPositionMultiHeadAttention(n_head=n_heads, n_feat=d_model, dropout_rate=dropout_att)
        elif self_attention_model == 'abs_pos':
            self.self_attn = MultiHeadAttention(n_head=n_heads, n_feat=d_model, dropout_rate=dropout_att)
        else:
            raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!")

        # second feed forward module
        self.norm_feed_forward2 = LayerNorm(d_model)
        self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff, dropout=dropout)

        # one dropout layer is reused for the output of every sub-module
        self.dropout = nn.Dropout(dropout)
        self.norm_out = LayerNorm(d_model)

    def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None):
        """
        Args:
            x (torch.Tensor): input signals (B, T, d_model)
            att_mask (torch.Tensor): attention masks(B, T, T)
            pos_emb (torch.Tensor): (L, 1, d_model); only passed to the
                attention layer when ``self_attention_model`` is 'rel_pos'
            pad_mask (torch.tensor): padding mask; accepted but not read by
                this block
        Returns:
            x (torch.Tensor): (B, T, d_model)
        Raises:
            ValueError: if ``self.self_attention_model`` was changed after
                construction to an unsupported value.
        """
        # first feed forward module (half-step residual)
        residual = x
        x = self.norm_feed_forward1(x)
        x = self.feed_forward1(x)
        x = self.fc_factor * self.dropout(x) + residual

        # multi-headed self-attention module
        residual = x
        x = self.norm_self_att(x)
        if self.self_attention_model == 'rel_pos':
            x = self.self_attn(query=x, key=x, value=x, pos_emb=pos_emb, mask=att_mask)
        elif self.self_attention_model == 'abs_pos':
            x = self.self_attn(query=x, key=x, value=x, mask=att_mask)
        else:
            # Fail with the same clear error as the constructor instead of
            # feeding None into dropout and crashing with an obscure message.
            raise ValueError(f"Not valid self_attention_model: '{self.self_attention_model}'!")
        x = self.dropout(x) + residual

        # convolution module
        residual = x
        x = self.norm_conv(x)
        x = self.conv(x)
        x = self.dropout(x) + residual

        # second feed forward module (half-step residual)
        residual = x
        x = self.norm_feed_forward2(x)
        x = self.feed_forward2(x)
        x = self.fc_factor * self.dropout(x) + residual

        x = self.norm_out(x)
        return x
class ConformerConvolution(nn.Module):
    """The convolution module for the Conformer model.

    Pipeline: pointwise conv (d_model -> 2 * d_model) -> GLU over channels
    (back to d_model) -> depthwise conv -> BatchNorm -> Swish -> pointwise
    conv (d_model -> d_model).

    Args:
        d_model (int): hidden dimension
        kernel_size (int): kernel size for depthwise convolution; must be odd
            so that the symmetric padding keeps the sequence length unchanged

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(self, d_model, kernel_size):
        super(ConformerConvolution, self).__init__()
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, and an even kernel would silently change the sequence
        # length produced by the depthwise convolution.
        if (kernel_size - 1) % 2 != 0:
            raise ValueError(f"kernel_size must be odd, but got {kernel_size}")
        self.d_model = d_model

        # doubles the channels so that GLU can split them into value and gate
        self.pointwise_conv1 = nn.Conv1d(
            in_channels=d_model, out_channels=d_model * 2, kernel_size=1, stride=1, padding=0, bias=True
        )
        # groups=d_model makes the convolution depthwise (one filter per channel)
        self.depthwise_conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
            groups=d_model,
            bias=True,
        )
        self.batch_norm = nn.BatchNorm1d(d_model)
        self.activation = Swish()
        self.pointwise_conv2 = nn.Conv1d(
            in_channels=d_model, out_channels=d_model, kernel_size=1, stride=1, padding=0, bias=True
        )

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): input with the channel dimension (d_model) last,
                i.e. (B, T, d_model)
        Returns:
            torch.Tensor: tensor with the same layout as the input
        """
        # Conv1d expects channels in dim 1, so swap time and channel dims
        x = x.transpose(1, 2)
        x = self.pointwise_conv1(x)
        # GLU halves the channel dimension: 2 * d_model -> d_model
        x = nn.functional.glu(x, dim=1)
        x = self.depthwise_conv(x)
        x = self.batch_norm(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)
        # restore the channel-last layout
        x = x.transpose(1, 2)
        return x
class ConformerFeedForward(nn.Module):
    """
    feed-forward module of Conformer model.

    Computes ``linear2(dropout(activation(linear1(x))))``.

    Args:
        d_model (int): size of the input and output features
        d_ff (int): size of the hidden layer
        dropout (float): dropout probability applied after the activation
        activation (nn.Module, optional): activation applied after the first
            linear layer. When omitted, a new ``Swish()`` is created for this
            instance.
    """

    def __init__(self, d_model, d_ff, dropout, activation=None):
        super(ConformerFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        # A `Swish()` default in the signature would be evaluated once at
        # definition time and shared by every feed-forward module; build a
        # fresh one per instance instead.
        self.activation = Swish() if activation is None else activation
        self.dropout = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): input whose last dimension is d_model
        Returns:
            torch.Tensor: output whose last dimension is d_model
        """
        x = self.linear1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x