diff --git a/nemo/collections/asr/modules/conformer_modules.py b/nemo/collections/asr/modules/conformer_modules.py index 8e6e0a0fefbc..5662ddbfe439 100644 --- a/nemo/collections/asr/modules/conformer_modules.py +++ b/nemo/collections/asr/modules/conformer_modules.py @@ -70,6 +70,7 @@ class ConformerFeedForward(nn.Module): """ feed-forward module of Conformer model. """ + def __init__(self, d_model, d_ff, dropout, activation=Swish()): super(ConformerFeedForward, self).__init__() self.linear1 = nn.Linear(d_model, d_ff)