diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8ad54f16b..d0e694f93 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,13 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 
 Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
+
+## [2.1.25]
+
+### Changed
+
+- Reverting PR #772 as it causes issues with `amp`.
+
 ## [2.1.24]
 
 ### Changed
 
diff --git a/sockeye/__init__.py b/sockeye/__init__.py
index 738323fae..e2782f453 100644
--- a/sockeye/__init__.py
+++ b/sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-__version__ = '2.1.24'
+__version__ = '2.1.25'
diff --git a/sockeye/decoder.py b/sockeye/decoder.py
index 48ce20ba9..cbb026923 100644
--- a/sockeye/decoder.py
+++ b/sockeye/decoder.py
@@ -142,8 +142,10 @@ def __init__(self,
                                                              prefix=C.TARGET_POSITIONAL_EMBEDDING_PREFIX,
                                                              scale_up_input=True,
                                                              scale_down_positions=False)
-
             self.autoregressive_bias = transformer.AutoRegressiveBias(prefix="autoregressive_bias_")
+            self.valid_length_mask = transformer.TransformerValidLengthMask(num_heads=self.config.attention_heads,
+                                                                            fold_heads=False,
+                                                                            name="bias")
             self.layers = mx.gluon.nn.HybridSequential()
             for i in range(config.num_layers):
                 self.layers.add(transformer.TransformerDecoderBlock(config, prefix="%d_" % i, dtype=dtype,
@@ -185,15 +187,14 @@ def init_state_from_encoder(self,
         :param encoder_valid_length: Valid lengths of encoder outputs. Shape: (batch,).
         :return: Initial states.
         """
-        # (batch, heads)
-        att_valid_length = encoder_valid_length.reshape((-1, 1)).repeat(repeats=self.config.attention_heads, axis=1)
+        source_mask = self.valid_length_mask(encoder_outputs, encoder_valid_length)
 
         # (batch_size, 1)
         step = mx.nd.expand_dims(mx.nd.zeros_like(encoder_valid_length), axis=1)
 
         if self.inference_only:
             # Encoder projection caching, therefore we don't pass the encoder_outputs
-            states = [step, att_valid_length]
+            states = [step, source_mask]
 
             for layer in self.layers:
                 encoder_attention_keys, encoder_attention_values = \
@@ -202,7 +203,7 @@
                 states.append(encoder_attention_values)
         else:
             # NO encoder projection caching
-            states = [step, encoder_outputs, att_valid_length]
+            states = [step, encoder_outputs, source_mask]
 
         batch_size = encoder_outputs.shape[0]
         dummy_autoregr_states = [mx.nd.zeros(layer.get_states_shape(batch_size),
@@ -274,8 +275,8 @@ def forward(self, step_input, states):
                 new_states = [step, states[1]] + encoder_attention_keys_values + autoregr_states
             else:
                 encoder_outputs = states[1]
-                att_valid_length = states[2]
-                new_states = [step, encoder_outputs, att_valid_length] + autoregr_states
+                source_mask = states[2]
+                new_states = [step, encoder_outputs, source_mask] + autoregr_states
 
             assert len(new_states) == len(states)
         else:
@@ -285,7 +286,7 @@
     def hybrid_forward(self, F, step_input, states):
         mask = None
         if self.inference_only:
-            steps, att_valid_length, *other = states
+            steps, source_mask, *other = states
             source_encoded = None  # use constant pre-computed key value projections from the states
             enc_att_kv = other[:self.config.num_layers * 2]
             enc_att_kv = [enc_att_kv[i:i + 2] for i in range(0, len(enc_att_kv), 2)]
@@ -293,7 +294,7 @@ def hybrid_forward(self, F, step_input, states):
         else:
             if any(layer.needs_mask for layer in self.layers):
                 mask = self.autoregressive_bias(step_input)  # mask: (1, length, length)
-            steps, source_encoded, att_valid_length, *autoregr_states = states
+            steps, source_encoded, source_mask, *autoregr_states = states
             enc_att_kv = [(None, None) for _ in range(self.config.num_layers)]
 
         if any(layer.num_state_tensors > 1 for layer in self.layers):
@@ -301,8 +302,8 @@ def hybrid_forward(self, F, step_input, states):
             states_iter = iter(autoregr_states)
             autoregr_states = [list(islice(states_iter, 0, layer.num_state_tensors)) for layer in self.layers]
 
-        # Fold the heads of source_mask (batch_size, num_heads, seq_len) -> (batch_size * num_heads, seq_len)
-        att_valid_length = F.reshape(att_valid_length, shape=(-3, -2))
+        # Fold the heads of source_mask (batch_size, num_heads, seq_len) -> (batch_size * num_heads, 1, seq_len)
+        source_mask = F.expand_dims(F.reshape(source_mask, shape=(-3, -2)), axis=1)
 
         # target: (batch_size, length, model_size)
         target = self.pos_embedding(step_input, steps)
@@ -315,7 +316,7 @@ def hybrid_forward(self, F, step_input, states):
             target, new_layer_autoregr_state = layer(target,
                                                      mask,
                                                      source_encoded,
-                                                     att_valid_length,
+                                                     source_mask,
                                                      layer_autoregr_state,
                                                      enc_att_k,
                                                      enc_att_v)
diff --git a/sockeye/encoder.py b/sockeye/encoder.py
index 23f61bc19..ec4ea41ea 100644
--- a/sockeye/encoder.py
+++ b/sockeye/encoder.py
@@ -308,6 +308,9 @@ def __init__(self,
                                                              prefix=C.SOURCE_POSITIONAL_EMBEDDING_PREFIX,
                                                              scale_up_input=True,
                                                              scale_down_positions=False)
+            self.valid_length_mask = transformer.TransformerValidLengthMask(num_heads=self.config.attention_heads,
+                                                                            fold_heads=True,
+                                                                            name="bias")
 
             self.layers = mx.gluon.nn.HybridSequential()
             for i in range(config.num_layers):
@@ -325,11 +328,11 @@ def hybrid_forward(self, F, data, valid_length):
         if self.config.dropout_prepost > 0.0:
             data = F.Dropout(data=data, p=self.config.dropout_prepost)
 
-        # (batch_size * heads,)
-        att_valid_length = F.repeat(valid_length, repeats=self.config.attention_heads, axis=0)
+        # (batch_size * heads, 1, seq_len)
+        bias = F.expand_dims(self.valid_length_mask(data, valid_length), axis=1)
 
         for block in self.layers:
-            data = block(data, att_valid_length)
+            data = block(data, bias)
 
         data = self.final_process(data, None)
         return data, valid_length
diff --git a/sockeye/layers.py b/sockeye/layers.py
index 691f7a8d2..52d1ff3a9 100644
--- a/sockeye/layers.py
+++ b/sockeye/layers.py
@@ -289,6 +289,31 @@ def combine_heads(F, x: mx.sym.Symbol, depth_per_head: int, heads: int) -> mx.sy
     return F.reshape(x, shape=(-1, 0, depth_per_head * heads))
 
 
+def broadcast_to_heads(F, x: mx.sym.Symbol, num_heads: int, ndim: int, fold_heads: bool = True) -> mx.sym.Symbol:
+    """
+    Broadcasts batch-major input of shape (batch, d1 ... dn-1) to (batch*heads, d1 ... dn-1).
+
+    :param x: Batch-major input. Shape: (batch, d1 ... dn-1).
+    :param num_heads: Number of heads.
+    :param ndim: Number of dimensions in x.
+    :param fold_heads: Whether to fold heads dimension into batch dimension.
+    :return: Tensor with each sample repeated heads-many times.
+             Shape: (batch * heads, d1 ... dn-1) if fold_heads == True, (batch, heads, d1 ... dn-1) else.
+    """
+    dims = [0] * (ndim - 1)
+    # x: (batch, 1)
+    x = F.expand_dims(x, axis=1)
+    # x: (batch, heads, dims...)
+    #x = F.broadcast_to(x, shape=[0, num_heads] + dims)
+    x = F.repeat(x, repeats=num_heads, axis=1)
+    if fold_heads:
+        # (batch * heads, dims...)
+        return F.reshape(x, shape=[-3] + dims)
+    else:
+        # x: (batch, heads, dims...)
+        return x
+
+
 class DotAttentionCell(mx.gluon.HybridBlock):
 
     def __init__(self, dropout: float = 0.0, prefix: str = '') -> None:
@@ -304,15 +329,23 @@ def hybrid_forward(self, F, queries, keys, values, lengths=None, bias=None):
         # (n, lq, lk)
         logits = F.batch_dot(lhs=queries, rhs=keys, transpose_b=True)
 
+        # TODO(fhieber): consider softmax with length argument once available.
+        # TODO(fhieber: Also see https://github.com/dmlc/gluon-nlp/pull/910
+        if lengths is not None:
+            # mask lk dimension
+            # (lk, n, lq)
+            logits = F.transpose(logits, axes=(2, 0, 1))
+            logits = F.SequenceMask(logits,
+                                    use_sequence_length=True,
+                                    sequence_length=lengths,
+                                    value=-C.LARGE_VALUES[self._dtype])
+            # (n, lq, lk)
+            logits = F.transpose(data=logits, axes=(1, 2, 0))
+
         if bias is not None:
             logits = F.broadcast_add(logits, bias)
 
-        if lengths is not None:
-            lengths = F.broadcast_like(F.expand_dims(lengths, axis=1), logits, lhs_axes=(1,), rhs_axes=(1,))
-            probs = F.softmax(logits, axis=-1, length=F.cast(lengths, dtype='int32'), use_length=True)
-        else:
-            probs = F.softmax(logits, axis=-1)
-
+        probs = F.softmax(logits, axis=-1)
         probs = F.Dropout(probs, p=self.dropout) if self.dropout > 0.0 else probs
 
         # (n, lq, lk) x (n, lk, dv) -> (n, lq, dv)
@@ -362,7 +395,7 @@ def _attend(self,
         :param queries: Query tensor. Shape: (batch_size, heads, query_max_length, depth_per_head).
         :param keys: Keys. Shape: (batch_size, heads, memory_max_length, depth_per_head).
         :param values: Values. Shape: (batch_size, heads, memory_max_length, depth_per_head).
-        :param lengths: Optional lengths of keys. Shape: (batch_size*heads,).
+        :param lengths: Optional lengths of keys. Shape: (batch_size,).
         :param bias: Optional 3d bias.
         :return: Context vectors. Shape: (batch_size, query_max_length, output_depth).
         """
@@ -371,6 +404,8 @@
         queries = F.reshape(queries, shape=(-3, -1, self.depth_per_head))
         keys = F.reshape(keys, shape=(-3, -1, self.depth_per_head))
         values = F.reshape(values, shape=(-3, -1, self.depth_per_head))
+        lengths = broadcast_to_heads(F, lengths, self.heads, ndim=1,
+                                     fold_heads=True) if lengths is not None else lengths
 
         # (batch*heads, query_max_length, depth_per_head)
         contexts = self.dot_att(queries, keys, values, lengths, bias)
diff --git a/sockeye/transformer.py b/sockeye/transformer.py
index 2dace2008..d764b0ec5 100644
--- a/sockeye/transformer.py
+++ b/sockeye/transformer.py
@@ -105,9 +105,9 @@ def __init__(self,
         if config.use_lhuc:
             self.lhuc = layers.LHUC(config.model_size)
 
-    def hybrid_forward(self, F, data: mx.sym.Symbol, lengths: mx.sym.Symbol) -> mx.sym.Symbol:
+    def hybrid_forward(self, F, data: mx.sym.Symbol, bias: mx.sym.Symbol) -> mx.sym.Symbol:
         # self-attention
-        data_self_att, _, __ = self.self_attention(self.pre_self_attention(data, None), [None, None], lengths, None)
+        data_self_att, _, __ = self.self_attention(self.pre_self_attention(data, None), [None, None], None, bias)
         data = self.post_self_attention(data_self_att, data)
 
         # feed-forward
@@ -215,7 +215,7 @@ def hybrid_forward(self, F,
                        target: mx.sym.Symbol,
                        target_bias: mx.sym.Symbol,
                        source: mx.sym.Symbol,
-                       source_att_lengths: mx.sym.Symbol,
+                       source_bias: mx.sym.Symbol,
                        autoregr_states: mx.sym.Symbol,
                        enc_att_k: Optional[mx.sym.Symbol] = None,
                        enc_att_v: Optional[mx.sym.Symbol] = None) -> Tuple[mx.sym.Symbol,
@@ -230,8 +230,8 @@ def hybrid_forward(self, F,
         # encoder attention
         target_enc_att = self.enc_attention(self.pre_enc_attention(target, None),
                                             source,
-                                            source_att_lengths,
                                             None,
+                                            source_bias,
                                             enc_att_k,
                                             enc_att_v)
 
@@ -328,6 +328,51 @@ def hybrid_forward(self, F, x):
         return y
 
+
+class TransformerValidLengthMask(mx.gluon.HybridBlock):
+    """
+    Returns bias/mask for variable sequence lengths.
+
+    :param num_heads: Number of attention heads.
+    :param fold_heads: Whether to fold heads dimension into batch dimension.
+    :param name: Name of symbol.
+    :return: Bias symbol. Shape: (batch, seq_len)
+    """
+    def __init__(self, num_heads: Optional[int] = None, fold_heads: bool = True, name: str = '') -> None:
+        super().__init__(prefix=name)
+        self.num_heads = num_heads
+        self.fold_heads = fold_heads
+        self._dtype = 'float32'
+
+    def cast(self, dtype):
+        self._dtype = dtype
+        super().cast(dtype)
+
+    def hybrid_forward(self, F, data, lengths):
+        """
+        Returns bias/mask for variable sequence lengths.
+
+        :param F: symbolic or ndarray.
+        :param data: Input data to mask. Shape: (batch, seq_len, _).
+        :param lengths: Sequence lengths. Shape: (batch,).
+        :return:
+        """
+        # (batch, 1)
+        mask = F.reshape(F.zeros_like(lengths.astype(self._dtype)), shape=(-1, 1))
+        # (batch, seq_len)
+        mask = F.broadcast_like(mask, data, lhs_axes=(1,), rhs_axes=(1,))
+        # (batch_size, max_length)
+        mask = F.SequenceMask(data=mask,
+                              use_sequence_length=True,
+                              sequence_length=lengths,
+                              axis=1,
+                              value=-C.LARGE_VALUES[self._dtype])
+        if self.num_heads is not None:
+            # (batch_size, heads, max_length) if fold_heads == False else (batch_size * heads, max_length)
+            mask = layers.broadcast_to_heads(F, mask, self.num_heads, ndim=2, fold_heads=self.fold_heads)
+
+        return F.BlockGrad(mask)
+
+
 class AutoRegressiveBias(mx.gluon.HybridBlock):
     def __init__(self, prefix: str = '',) -> None:
         super().__init__(prefix=prefix)