Commit

Reverting PR #772. (#883)
* Reverting PR #772.

* Correct version.
tdomhan committed Sep 22, 2020
1 parent 5d66f4e commit 76b049b
Showing 6 changed files with 118 additions and 27 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,13 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.


## [2.1.25]

### Changed

- Reverting PR #772 as it causes issues with `amp`.

## [2.1.24]

### Changed
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '2.1.24'
__version__ = '2.1.25'
25 changes: 13 additions & 12 deletions sockeye/decoder.py
@@ -142,8 +142,10 @@ def __init__(self,
prefix=C.TARGET_POSITIONAL_EMBEDDING_PREFIX,
scale_up_input=True,
scale_down_positions=False)

self.autoregressive_bias = transformer.AutoRegressiveBias(prefix="autoregressive_bias_")
self.valid_length_mask = transformer.TransformerValidLengthMask(num_heads=self.config.attention_heads,
fold_heads=False,
name="bias")
self.layers = mx.gluon.nn.HybridSequential()
for i in range(config.num_layers):
self.layers.add(transformer.TransformerDecoderBlock(config, prefix="%d_" % i, dtype=dtype,
@@ -185,15 +187,14 @@ def init_state_from_encoder(self,
:param encoder_valid_length: Valid lengths of encoder outputs. Shape: (batch,).
:return: Initial states.
"""
# (batch, heads)
att_valid_length = encoder_valid_length.reshape((-1, 1)).repeat(repeats=self.config.attention_heads, axis=1)
source_mask = self.valid_length_mask(encoder_outputs, encoder_valid_length)

# (batch_size, 1)
step = mx.nd.expand_dims(mx.nd.zeros_like(encoder_valid_length), axis=1)

if self.inference_only:
# Encoder projection caching, therefore we don't pass the encoder_outputs
states = [step, att_valid_length]
states = [step, source_mask]

for layer in self.layers:
encoder_attention_keys, encoder_attention_values = \
@@ -202,7 +203,7 @@ def init_state_from_encoder(self,
states.append(encoder_attention_values)
else:
# NO encoder projection caching
states = [step, encoder_outputs, att_valid_length]
states = [step, encoder_outputs, source_mask]

batch_size = encoder_outputs.shape[0]
dummy_autoregr_states = [mx.nd.zeros(layer.get_states_shape(batch_size),
@@ -274,8 +275,8 @@ def forward(self, step_input, states):
new_states = [step, states[1]] + encoder_attention_keys_values + autoregr_states
else:
encoder_outputs = states[1]
att_valid_length = states[2]
new_states = [step, encoder_outputs, att_valid_length] + autoregr_states
source_mask = states[2]
new_states = [step, encoder_outputs, source_mask] + autoregr_states

assert len(new_states) == len(states)
else:
@@ -285,24 +286,24 @@ def forward(self, step_input, states):
def hybrid_forward(self, F, step_input, states):
mask = None
if self.inference_only:
steps, att_valid_length, *other = states
steps, source_mask, *other = states
source_encoded = None # use constant pre-computed key value projections from the states
enc_att_kv = other[:self.config.num_layers * 2]
enc_att_kv = [enc_att_kv[i:i + 2] for i in range(0, len(enc_att_kv), 2)]
autoregr_states = other[self.config.num_layers * 2:]
else:
if any(layer.needs_mask for layer in self.layers):
mask = self.autoregressive_bias(step_input) # mask: (1, length, length)
steps, source_encoded, att_valid_length, *autoregr_states = states
steps, source_encoded, source_mask, *autoregr_states = states
enc_att_kv = [(None, None) for _ in range(self.config.num_layers)]

if any(layer.num_state_tensors > 1 for layer in self.layers):
# separates autoregressive states by layer
states_iter = iter(autoregr_states)
autoregr_states = [list(islice(states_iter, 0, layer.num_state_tensors)) for layer in self.layers]

# Fold the heads of source_mask (batch_size, num_heads, seq_len) -> (batch_size * num_heads, seq_len)
att_valid_length = F.reshape(att_valid_length, shape=(-3, -2))
# Fold the heads of source_mask (batch_size, num_heads, seq_len) -> (batch_size * num_heads, 1, seq_len)
source_mask = F.expand_dims(F.reshape(source_mask, shape=(-3, -2)), axis=1)

# target: (batch_size, length, model_size)
target = self.pos_embedding(step_input, steps)
Expand All @@ -315,7 +316,7 @@ def hybrid_forward(self, F, step_input, states):
target, new_layer_autoregr_state = layer(target,
mask,
source_encoded,
att_valid_length,
source_mask,
layer_autoregr_state,
enc_att_k, enc_att_v)

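A minimal sketch (toy shapes, not part of the diff) of the fold applied to `source_mask` in the decoder's `hybrid_forward` above: the mask is kept in the decoder states as (batch_size, num_heads, seq_len) and reshaped to (batch_size * num_heads, 1, seq_len) so it broadcasts over the query-length axis of the attention logits.

```python
import mxnet as mx

batch_size, num_heads, seq_len = 2, 8, 5
source_mask = mx.nd.zeros((batch_size, num_heads, seq_len))   # as kept in the decoder states

# -3 merges the batch and head axes, -2 keeps the remaining axis unchanged
folded = mx.nd.reshape(source_mask, shape=(-3, -2))           # (16, 5)
folded = mx.nd.expand_dims(folded, axis=1)                    # (16, 1, 5)
print(folded.shape)
```
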
9 changes: 6 additions & 3 deletions sockeye/encoder.py
@@ -308,6 +308,9 @@ def __init__(self,
prefix=C.SOURCE_POSITIONAL_EMBEDDING_PREFIX,
scale_up_input=True,
scale_down_positions=False)
self.valid_length_mask = transformer.TransformerValidLengthMask(num_heads=self.config.attention_heads,
fold_heads=True,
name="bias")

self.layers = mx.gluon.nn.HybridSequential()
for i in range(config.num_layers):
@@ -325,11 +328,11 @@ def hybrid_forward(self, F, data, valid_length):
if self.config.dropout_prepost > 0.0:
data = F.Dropout(data=data, p=self.config.dropout_prepost)

# (batch_size * heads,)
att_valid_length = F.repeat(valid_length, repeats=self.config.attention_heads, axis=0)
# (batch_size * heads, 1, seq_len)
bias = F.expand_dims(self.valid_length_mask(data, valid_length), axis=1)

for block in self.layers:
data = block(data, att_valid_length)
data = block(data, bias)

data = self.final_process(data, None)
return data, valid_length
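Toy shapes (assumed, not part of the diff) for the encoder-side bias built above: with `fold_heads=True` the valid-length mask already comes out as (batch_size * heads, seq_len), and the `expand_dims` call makes it (batch_size * heads, 1, seq_len) so it can be broadcast-added to per-head attention logits.

```python
import mxnet as mx

batch_size, heads, seq_len = 2, 4, 6

# shape produced by the valid-length mask with fold_heads=True
bias = mx.nd.zeros((batch_size * heads, seq_len))
bias = mx.nd.expand_dims(bias, axis=1)                        # (8, 1, 6)

logits = mx.nd.zeros((batch_size * heads, seq_len, seq_len))  # per-head attention logits
print(mx.nd.broadcast_add(logits, bias).shape)                # (8, 6, 6)
```
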
49 changes: 42 additions & 7 deletions sockeye/layers.py
@@ -289,6 +289,31 @@ def combine_heads(F, x: mx.sym.Symbol, depth_per_head: int, heads: int) -> mx.sy
return F.reshape(x, shape=(-1, 0, depth_per_head * heads))


def broadcast_to_heads(F, x: mx.sym.Symbol, num_heads: int, ndim: int, fold_heads: bool = True) -> mx.sym.Symbol:
"""
Broadcasts batch-major input of shape (batch, d1 ... dn-1) to (batch*heads, d1 ... dn-1).
:param x: Batch-major input. Shape: (batch, d1 ... dn-1).
:param num_heads: Number of heads.
:param ndim: Number of dimensions in x.
:param fold_heads: Whether to fold heads dimension into batch dimension.
:return: Tensor with each sample repeated heads-many times.
Shape: (batch * heads, d1 ... dn-1) if fold_heads == True, (batch, heads, d1 ... dn-1) else.
"""
dims = [0] * (ndim - 1)
# x: (batch, 1)
x = F.expand_dims(x, axis=1)
# x: (batch, heads, dims...)
#x = F.broadcast_to(x, shape=[0, num_heads] + dims)
x = F.repeat(x, repeats=num_heads, axis=1)
if fold_heads:
# (batch * heads, dims...)
return F.reshape(x, shape=[-3] + dims)
else:
# x: (batch, heads, dims...)
return x


class DotAttentionCell(mx.gluon.HybridBlock):

def __init__(self, dropout: float = 0.0, prefix: str = '') -> None:
@@ -304,15 +329,23 @@ def hybrid_forward(self, F, queries, keys, values, lengths=None, bias=None):
# (n, lq, lk)
logits = F.batch_dot(lhs=queries, rhs=keys, transpose_b=True)

# TODO(fhieber): consider softmax with length argument once available.
# TODO(fhieber): Also see https://github.com/dmlc/gluon-nlp/pull/910
if lengths is not None:
# mask lk dimension
# (lk, n, lq)
logits = F.transpose(logits, axes=(2, 0, 1))
logits = F.SequenceMask(logits,
use_sequence_length=True,
sequence_length=lengths,
value=-C.LARGE_VALUES[self._dtype])
# (n, lq, lk)
logits = F.transpose(data=logits, axes=(1, 2, 0))

if bias is not None:
logits = F.broadcast_add(logits, bias)

if lengths is not None:
lengths = F.broadcast_like(F.expand_dims(lengths, axis=1), logits, lhs_axes=(1,), rhs_axes=(1,))
probs = F.softmax(logits, axis=-1, length=F.cast(lengths, dtype='int32'), use_length=True)
else:
probs = F.softmax(logits, axis=-1)

probs = F.softmax(logits, axis=-1)
probs = F.Dropout(probs, p=self.dropout) if self.dropout > 0.0 else probs

# (n, lq, lk) x (n, lk, dv) -> (n, lq, dv)
@@ -362,7 +395,7 @@ def _attend(self,
:param queries: Query tensor. Shape: (batch_size, heads, query_max_length, depth_per_head).
:param keys: Keys. Shape: (batch_size, heads, memory_max_length, depth_per_head).
:param values: Values. Shape: (batch_size, heads, memory_max_length, depth_per_head).
:param lengths: Optional lengths of keys. Shape: (batch_size*heads,).
:param lengths: Optional lengths of keys. Shape: (batch_size,).
:param bias: Optional 3d bias.
:return: Context vectors. Shape: (batch_size, query_max_length, output_depth).
"""
@@ -371,6 +404,8 @@
queries = F.reshape(queries, shape=(-3, -1, self.depth_per_head))
keys = F.reshape(keys, shape=(-3, -1, self.depth_per_head))
values = F.reshape(values, shape=(-3, -1, self.depth_per_head))
lengths = broadcast_to_heads(F, lengths, self.heads, ndim=1,
fold_heads=True) if lengths is not None else lengths

# (batch*heads, query_max_length, depth_per_head)
contexts = self.dot_att(queries, keys, values, lengths, bias)
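A short walk-through (toy values, not part of the diff; `LARGE` is a stand-in for `C.LARGE_VALUES['float32']`) of the restored length-masking path in `DotAttentionCell`: per-sequence lengths are repeated per head by `broadcast_to_heads`, the key axis is masked with `SequenceMask`, and a plain softmax follows rather than softmax's `length`/`use_length` arguments.

```python
import mxnet as mx

LARGE = 1e8                              # stand-in for C.LARGE_VALUES['float32']
batch, heads, lq, lk = 1, 2, 2, 3

# (batch,) -> (batch * heads,), as broadcast_to_heads(F, lengths, heads, ndim=1, fold_heads=True) does
lengths = mx.nd.array([2])
lengths = mx.nd.reshape(mx.nd.repeat(mx.nd.expand_dims(lengths, axis=1), repeats=heads, axis=1),
                        shape=(-3,))     # [2. 2.]

logits = mx.nd.zeros((batch * heads, lq, lk))

# mask the key dimension: (n, lq, lk) -> (lk, n, lq), SequenceMask, back to (n, lq, lk)
masked = mx.nd.transpose(logits, axes=(2, 0, 1))
masked = mx.nd.SequenceMask(masked, use_sequence_length=True, sequence_length=lengths, value=-LARGE)
masked = mx.nd.transpose(masked, axes=(1, 2, 0))

probs = mx.nd.softmax(masked, axis=-1)
print(probs.asnumpy())                   # the padded third key position gets ~0 weight
```
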
53 changes: 49 additions & 4 deletions sockeye/transformer.py
@@ -105,9 +105,9 @@ def __init__(self,
if config.use_lhuc:
self.lhuc = layers.LHUC(config.model_size)

def hybrid_forward(self, F, data: mx.sym.Symbol, lengths: mx.sym.Symbol) -> mx.sym.Symbol:
def hybrid_forward(self, F, data: mx.sym.Symbol, bias: mx.sym.Symbol) -> mx.sym.Symbol:
# self-attention
data_self_att, _, __ = self.self_attention(self.pre_self_attention(data, None), [None, None], lengths, None)
data_self_att, _, __ = self.self_attention(self.pre_self_attention(data, None), [None, None], None, bias)
data = self.post_self_attention(data_self_att, data)

# feed-forward
@@ -215,7 +215,7 @@ def hybrid_forward(self, F,
target: mx.sym.Symbol,
target_bias: mx.sym.Symbol,
source: mx.sym.Symbol,
source_att_lengths: mx.sym.Symbol,
source_bias: mx.sym.Symbol,
autoregr_states: mx.sym.Symbol,
enc_att_k: Optional[mx.sym.Symbol] = None,
enc_att_v: Optional[mx.sym.Symbol] = None) -> Tuple[mx.sym.Symbol,
Expand All @@ -230,8 +230,8 @@ def hybrid_forward(self, F,
# encoder attention
target_enc_att = self.enc_attention(self.pre_enc_attention(target, None),
source,
source_att_lengths,
None,
source_bias,
enc_att_k,
enc_att_v)

@@ -328,6 +328,51 @@ def hybrid_forward(self, F, x):
return y


class TransformerValidLengthMask(mx.gluon.HybridBlock):
"""
Returns bias/mask for variable sequence lengths.
:param num_heads: Number of attention heads.
:param fold_heads: Whether to fold heads dimension into batch dimension.
:param name: Name of symbol.
:return: Bias symbol. Shape: (batch, seq_len)
"""
def __init__(self, num_heads: Optional[int] = None, fold_heads: bool = True, name: str = '') -> None:
super().__init__(prefix=name)
self.num_heads = num_heads
self.fold_heads = fold_heads
self._dtype = 'float32'

def cast(self, dtype):
self._dtype = dtype
super().cast(dtype)

def hybrid_forward(self, F, data, lengths):
"""
Returns bias/mask for variable sequence lengths.
:param F: symbolic or ndarray.
:param data: Input data to mask. Shape: (batch, seq_len, _).
:param lengths: Sequence lengths. Shape: (batch,).
:return:
"""
# (batch, 1)
mask = F.reshape(F.zeros_like(lengths.astype(self._dtype)), shape=(-1, 1))
# (batch, seq_len)
mask = F.broadcast_like(mask, data, lhs_axes=(1,), rhs_axes=(1,))
# (batch_size, max_length)
mask = F.SequenceMask(data=mask,
use_sequence_length=True,
sequence_length=lengths,
axis=1,
value=-C.LARGE_VALUES[self._dtype])
if self.num_heads is not None:
# (batch_size, heads, max_length) if fold_heads == False else (batch_size * heads, max_length)
mask = layers.broadcast_to_heads(F, mask, self.num_heads, ndim=2, fold_heads=self.fold_heads)

return F.BlockGrad(mask)


class AutoRegressiveBias(mx.gluon.HybridBlock):
def __init__(self, prefix: str = '',) -> None:
super().__init__(prefix=prefix)
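To illustrate what `TransformerValidLengthMask` above produces (toy values, not part of the diff; `LARGE` is a stand-in for `C.LARGE_VALUES['float32']`): valid positions receive 0 and padded positions a large negative value, so adding the mask to attention logits before the softmax drives attention to padding toward zero.

```python
import mxnet as mx

LARGE = 1e8                                   # stand-in for C.LARGE_VALUES['float32']
lengths = mx.nd.array([3, 1])                 # two sequences of length 3 and 1
data = mx.nd.zeros((2, 4, 8))                 # (batch, seq_len, model_size), seq_len = 4

mask = mx.nd.zeros_like(lengths).reshape((-1, 1))                      # (2, 1)
mask = mx.nd.broadcast_like(mask, data, lhs_axes=(1,), rhs_axes=(1,))  # (2, 4)
mask = mx.nd.SequenceMask(mask, use_sequence_length=True, sequence_length=lengths,
                          axis=1, value=-LARGE)
print(mask.asnumpy())
# -> values: [[0, 0, 0, -1e8], [0, -1e8, -1e8, -1e8]]
# With num_heads set and fold_heads=True the mask is then repeated per head by
# layers.broadcast_to_heads; broadcast-adding it to the logits gives padded
# positions ~0 softmax weight.
```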
