Add rezero option; register ops for APEX AMP training #192

Open · wants to merge 5 commits into base: 3.0.0-dev
2 changes: 1 addition & 1 deletion experiments/srupp_experiments/train_enwik8.py
@@ -41,7 +41,7 @@ def __init__(self, words, args):
 
     def init_weights(self):
         params = list(self.embedding_layer.parameters()) + list(self.output_layer.parameters())
-        new_init_version = self.args.layer_norm
+        new_init_version = True
         for p in params:
             if p.dim() > 1: # matrix
                 # keep old init version for reproducibility
2 changes: 1 addition & 1 deletion experiments/srupp_experiments/train_lm1b.py
@@ -59,7 +59,7 @@ def __init__(self, args):
             dropout=args.emb_dropout,
             emb_weights=shared_embs,
             proj_weights=shared_projs,
-            scale_emb=not args.layer_norm,
+            scale_emb=False,
         )
         self.init_weights()
 
2 changes: 1 addition & 1 deletion experiments/srupp_experiments/train_wt103.py
@@ -60,7 +60,7 @@ def __init__(self, args):
             dropout=args.emb_dropout,
             emb_weights=shared_embs,
             proj_weights=shared_projs,
-            scale_emb=not args.layer_norm,
+            scale_emb=False,
         )
         self.init_weights()
 
35 changes: 25 additions & 10 deletions sru/modules.py
@@ -703,9 +703,10 @@ def __init__(self,
                  num_heads: int = 1,
                  dropout: float = 0.0,
                  attn_dropout: float = 0.0,
-                 rezero_init_alpha: float = 0.0,
                  layer_norm: bool = False,
-                 normalize_after: bool = True):
+                 normalize_after: bool = True,
+                 use_rezero: bool = True,
+                 rezero_init_alpha: float = 0.0):
         """Initialize the self-attention module.
 
         Parameters
@@ -726,13 +727,14 @@ def __init__(self,
             (default=0.0).
         attn_dropout: float, optional
             dropout probability applied on attention map.
-        rezero_init_alpha: float, optional
-            initial scalar value for the attention transformation `x + alpha * Attention(x)`
-            (default=0).
         normalize_after: bool, optional
             if True, apply post layer normalization; otherwise apply pre layer normalization
             (default=True).
-
+        use_rezero: bool, optional
+            if True, apply rezero init using a scalar value `alpha` for the attention transformation
+            `x + alpha * Attention(x)`.
+        rezero_init_alpha: float, optional
+            initial value of the scalar `alpha` when using rezero (default=0).
         """
         super(SRUppAttention, self).__init__()
         self.in_features = in_features
@@ -745,7 +747,9 @@ def __init__(self,
         self.linear1 = nn.Linear(in_features, proj_features, bias=False)
         self.linear2 = nn.Linear(proj_features, proj_features * 2, bias=False)
         self.linear3 = nn.Linear(proj_features, out_features, bias=False)
-        self.alpha = nn.Parameter(torch.Tensor([float(rezero_init_alpha)])) # type: ignore
+        self.alpha: Optional[nn.Parameter] = None
+        if use_rezero:
+            self.alpha = nn.Parameter(torch.Tensor([float(rezero_init_alpha)])) # type: ignore
         self.normalize_after = normalize_after
         self.layer_norm: Optional[nn.Module] = None
         if layer_norm:
@@ -761,7 +765,10 @@ def reset_parameters(self):
         nn.init.xavier_uniform_(self.linear1.weight)
         nn.init.xavier_uniform_(self.linear2.weight)
         nn.init.xavier_uniform_(self.linear3.weight)
-        self.alpha.data[:] = self.rezero_init_alpha
+        if self.alpha is not None:
+            self.alpha.data[:] = self.rezero_init_alpha
+        else:
+            self.linear2.weight.data[self.proj_features:].mul_(0.0)
         if self.linear1.bias is not None:
             self.linear1.bias.data.zero_()
         if self.linear2.bias is not None:
@@ -861,7 +868,10 @@ def forward(self,
         attn_output = torch.bmm(attn_output_weights, v)
         attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, proj_dim)
 
-        attn_output = attn_output * self.alpha + residual
+        if self.alpha is not None:
+            attn_output = attn_output * self.alpha + residual
+        else:
+            attn_output = attn_output + residual
         layer_norm = self.layer_norm
         if layer_norm is not None:
             if self.normalize_after:
@@ -963,6 +973,7 @@ def __init__(self,
                  normalize_after: bool = False,
                  attn_layer_norm: bool = True,
                  highway_bias: float = -2.0,
+                 use_rezero: bool = True,
                  attention_every_n_layers: int = 1,
                  attention_last_n_layers: int = -1,
                  rescale: bool = False,
@@ -999,6 +1010,9 @@ def __init__(self,
             attention is disabled (default=True).
         highway_bias: float, optional
             the initial value of the bias used in the highway (sigmoid) gate (default=-1.0).
+        use_rezero: bool, optional
+            if True, apply rezero init using a scalar value `alpha` for the attention transformation
+            `x + alpha * Attention(x)`.
         attention_every_n_layers: int, optional
             only introduce attention every few layers of SRU++. by default, every SRU++ layer has
             attention (default=1).
@@ -1053,10 +1067,11 @@ def __init__(self,
                     in_features,
                     out_features,
                     proj_features,
-                    num_heads=num_heads,
                     dropout=dropout,
                     attn_dropout=attn_dropout,
+                    num_heads=num_heads,
                     layer_norm=attn_layer_norm,
+                    use_rezero=use_rezero,
                 )
             else:
                 custom_m = SRUppProjectedLinear(
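The modules.py changes above wire a use_rezero flag from SRUpp down into SRUppAttention. Below is a minimal, self-contained sketch of the residual pattern that flag toggles; the ReZeroResidual class and its names are illustrative only, not SRU library code. With rezero enabled, the sublayer output is scaled by a learned scalar alpha initialized to rezero_init_alpha (0 by default), so the block starts out as an identity mapping; with it disabled, a plain residual connection x + f(x) is used.

import torch
import torch.nn as nn
from typing import Optional


class ReZeroResidual(nn.Module):
    """Toy illustration of the rezero pattern (not SRU code).

    use_rezero=True:  y = x + alpha * f(x), with alpha starting at 0
    use_rezero=False: y = x + f(x)
    """

    def __init__(self, sublayer: nn.Module, use_rezero: bool = True,
                 rezero_init_alpha: float = 0.0):
        super().__init__()
        self.sublayer = sublayer
        self.alpha: Optional[nn.Parameter] = None
        if use_rezero:
            # Learned scalar gate on the sublayer branch, initialized at 0.
            self.alpha = nn.Parameter(torch.tensor([float(rezero_init_alpha)]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.sublayer(x)
        if self.alpha is not None:
            return x + self.alpha * out
        return x + out


# At initialization the rezero block behaves as the identity.
block = ReZeroResidual(nn.Linear(8, 8), use_rezero=True)
x = torch.randn(2, 8)
assert torch.allclose(block(x), x)

When use_rezero is False, the reset_parameters fallback in the diff instead zeroes part of linear2's weight, which appears intended to give the attention branch a similarly quiet start without the extra learned scalar.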
109 changes: 82 additions & 27 deletions sru/ops.py
@@ -20,6 +20,35 @@
 )
 
 
+# Wrapper functions for APEX AMP
+def apex_amp_sru_elementwise_fp16(*args, **kwargs):
+    return _apex_amp_sru_elementwise(*args, **kwargs)
+
+
+def apex_amp_sru_elementwise_fp32(*args, **kwargs):
+    return _apex_amp_sru_elementwise(*args, **kwargs)
+
+
+def _apex_amp_sru_elementwise(*args, **kwargs):
+    # Will already have been imported and cached at this point
+    from .cuda_functional import ElementwiseRecurrence
+    return ElementwiseRecurrence.apply(*args, **kwargs)
+
+
+# Register the elementwise recurrence operator as APEX AMP functions when APEX is available
+try:
+    from apex import amp
+    import sys
+    APEX_AMP_AVAILABLE = True
+    current_module = sys.modules[__name__]
+    warnings.warn("Registering SRU op in {}".format(current_module))
+
+    amp.register_half_function(current_module, "apex_amp_sru_elementwise_fp16")
+    amp.register_float_function(current_module, "apex_amp_sru_elementwise_fp32")
+except ImportError:
+    APEX_AMP_AVAILABLE = False
+
+
 @torch.jit.script
 def elementwise_recurrence_inference(U: Tensor,
                                      x: Tensor,
@@ -112,35 +141,61 @@ def elementwise_recurrence_gpu(U: Tensor,
     """Elementwise forward operation of SRU on GPU.
 
     """
-    from .cuda_functional import ElementwiseRecurrence
+    in_autocast = getattr(torch, 'is_autocast_enabled', lambda: False)()
 
-    if amp_recurrence_fp16 and U.dtype == torch.float16:
-        cast = torch.Tensor.half
+    # APEX is available and not in native PyTorch AMP autocast
+    if APEX_AMP_AVAILABLE and not in_autocast:
+        if amp_recurrence_fp16 and U.dtype == torch.float16:
+            warnings.warn("Running SRU with APEX and cast type {}".format(torch.Tensor.half))
+            apex_sru_elementwise_gpu = apex_amp_sru_elementwise_fp16
+        else:
+            warnings.warn("Running SRU with APEX and cast type {}".format(torch.Tensor.float))
+            apex_sru_elementwise_gpu = apex_amp_sru_elementwise_fp32
+
+        return apex_sru_elementwise_gpu(
+            U,
+            x,
+            weight_c,
+            bias,
+            c_init,
+            activation_type,
+            hidden_size,
+            bidirectional,
+            has_skip_term,
+            scale_x,
+            dropout_mask_c,
+            mask_pad
+        )
     else:
-        cast = torch.Tensor.float
-
-    U = cast(U)
-    x = cast(x)
-    weight_c = cast(weight_c)
-    bias = cast(bias)
-    c_init = cast(c_init)
-    scale_x = cast(scale_x) if scale_x is not None else scale_x
-    dropout_mask_c = cast(dropout_mask_c) if dropout_mask_c is not None else dropout_mask_c
-
-    return ElementwiseRecurrence.apply(
-        U,
-        x,
-        weight_c,
-        bias,
-        c_init,
-        activation_type,
-        hidden_size,
-        bidirectional,
-        has_skip_term,
-        scale_x,
-        dropout_mask_c,
-        mask_pad
-    )
+        from .cuda_functional import ElementwiseRecurrence
+        if amp_recurrence_fp16 and U.dtype == torch.float16:
+            cast = torch.Tensor.half
+        else:
+            cast = torch.Tensor.float
+        warnings.warn("Running SRU with native AMP and cast type {}".format(cast))
+
+        U = cast(U)
+        x = cast(x)
+        weight_c = cast(weight_c)
+        bias = cast(bias)
+        c_init = cast(c_init)
+        scale_x = cast(scale_x) if scale_x is not None else scale_x
+        dropout_mask_c = cast(dropout_mask_c) if dropout_mask_c is not None else dropout_mask_c
+
+        return ElementwiseRecurrence.apply(
+            U,
+            x,
+            weight_c,
+            bias,
+            c_init,
+            activation_type,
+            hidden_size,
+            bidirectional,
+            has_skip_term,
+            scale_x,
+            dropout_mask_c,
+            mask_pad
+        )
 
 
 @torch.jit.unused
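The ops.py changes follow a standard APEX AMP pattern: module-level wrapper functions are registered with amp.register_half_function and amp.register_float_function so APEX casts their tensor arguments before the underlying CUDA kernel runs, and elementwise_recurrence_gpu dispatches to the fp16 or fp32 wrapper only when APEX is installed and native torch autocast is not active (the getattr guard keeps older PyTorch builds without is_autocast_enabled working). Here is a minimal sketch of the same registration pattern applied to a toy op; the names my_matmul_fp16 and my_matmul_fp32 are illustrative, not part of SRU.

import sys
import torch

try:
    from apex import amp
    APEX_AMP_AVAILABLE = True
except ImportError:
    APEX_AMP_AVAILABLE = False


def my_matmul_fp16(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Once registered as a "half function", APEX casts a and b to fp16 before the call.
    return a @ b


def my_matmul_fp32(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Once registered as a "float function", APEX casts a and b to fp32 before the call.
    return a @ b


if APEX_AMP_AVAILABLE:
    # Registration takes the module object plus the function name, and should
    # happen before amp.initialize() patches the functions in place.
    this_module = sys.modules[__name__]
    amp.register_half_function(this_module, "my_matmul_fp16")
    amp.register_float_function(this_module, "my_matmul_fp32")

Registering two differently typed wrappers around the same ElementwiseRecurrence.apply call is what lets the amp_recurrence_fp16 flag choose the cast behavior at call time rather than at registration time.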