post_ln => normalize_after
hpasapp committed Dec 17, 2020
1 parent 7e4c758 commit b8235f0
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions sru/modules.py
@@ -42,7 +42,7 @@ def __init__(self,
v1: bool = False,
custom_m: Optional[nn.Module] = None,
amp_recurrence_fp16: bool = False,
-post_ln: bool = False):
+normalize_after: bool = False):
"""Initialize the SRUCell module.
Parameters
@@ -95,7 +95,7 @@ def __init__(self,
When using AMP autocast, selects which type to use
for recurrence custom kernel.
False: torch.float32, True: torch.float16
-post_ln: bool
+normalize_after: bool
if True use post layer norm, else pre layer norm
"""
super(SRUCell, self).__init__()
@@ -116,7 +116,7 @@ def __init__(self,
self.activation_type = 1
self.activation = 'tanh'
self.amp_recurrence_fp16 = amp_recurrence_fp16
-self.post_ln = post_ln
+self.normalize_after = normalize_after

# projection dimension
self.projection_size = 0
@@ -150,7 +150,7 @@ def __init__(self,

self.layer_norm: Optional[nn.Module]= None
if layer_norm:
-if post_ln:
+if normalize_after:
self.layer_norm = nn.LayerNorm(self.output_size)
else:
self.layer_norm = nn.LayerNorm(self.input_size)
@@ -242,7 +242,7 @@ def forward(self,

# apply layer norm before activation (i.e. before SRU computation)
residual = input
-if self.layer_norm is not None and not self.post_ln:
+if self.layer_norm is not None and not self.normalize_after:
input = self.layer_norm(input)

# apply dropout for multiplication
@@ -268,7 +268,7 @@ def forward(self,
# apply elementwise recurrence to get hidden states h and c
h, c = self.apply_recurrence(U, V, residual, c0, scale_val, mask_c, mask_pad)

-if self.layer_norm is not None and self.post_ln:
+if self.layer_norm is not None and self.normalize_after:
h = self.layer_norm(h)

return h, c
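Note: the two forward() hunks above together determine where the optional LayerNorm runs relative to the recurrence. Below is a minimal, self-contained sketch of that placement only; run and toy_recurrence are illustrative names, and the dummy recurrence is an assumption standing in for SRUCell.apply_recurrence, not the real SRU kernel.

    import torch
    import torch.nn as nn

    def toy_recurrence(x, residual):
        # Dummy stand-in for SRUCell.apply_recurrence; not the real SRU kernel.
        return x + residual

    def run(x, layer_norm, recurrence, normalize_after=False):
        # Mirrors the branches in the diff: pre-norm normalizes the input before
        # the recurrence, post-norm normalizes the hidden states it produces.
        residual = x
        if layer_norm is not None and not normalize_after:
            x = layer_norm(x)        # pre layer norm
        h = recurrence(x, residual)  # elementwise recurrence goes here
        if layer_norm is not None and normalize_after:
            h = layer_norm(h)        # post layer norm
        return h

    length, batch, d = 5, 2, 8
    x = torch.randn(length, batch, d)
    ln = nn.LayerNorm(d)
    h_pre = run(x, ln, toy_recurrence, normalize_after=False)
    h_post = run(x, ln, toy_recurrence, normalize_after=True)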
@@ -439,7 +439,7 @@ def __init__(self,
custom_m: Optional[Union[nn.Module, List[nn.Module]]] = None,
proj_input_to_hidden_first: bool = False,
amp_recurrence_fp16: bool = False,
-post_ln: bool = False):
+normalize_after: bool = False):
"""Initialize the SRU module.
Parameters
@@ -498,7 +498,7 @@ def __init__(self,
When using AMP autocast, selects which type to use
for recurrence custom kernel.
False: torch.float32, True: torch.float16
-post_ln: bool
+normalize_after: bool
if True use post layer norm, else use pre layer norm
"""

@@ -552,7 +552,7 @@ def __init__(self,
v1=v1,
custom_m=custom_m_i,
amp_recurrence_fp16=amp_recurrence_fp16,
-post_ln=post_ln
+normalize_after=normalize_after
)
rnn_lst.append(layer_i)
self.rnn_lst = rnn_lst
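Usage after this rename, as a minimal sketch: the flag is passed as normalize_after at both the SRU and SRUCell level. The constructor keywords and input shape below follow the repository's usual SRU(input_size, hidden_size, ...) API but are assumptions rather than a definitive example; note that layer_norm=True is required for normalize_after to have any effect.

    import torch
    from sru import SRU

    # Hypothetical configuration; sizes are arbitrary.
    rnn = SRU(input_size=128,
              hidden_size=128,
              num_layers=2,
              layer_norm=True,          # enable LayerNorm at all
              normalize_after=True)     # True: post layer norm, False: pre layer norm

    x = torch.randn(50, 4, 128)         # (length, batch, input_size)
    output_states, c_states = rnn(x)

Call sites that previously passed post_ln=... need to switch to the new normalize_after keyword.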