Add blocking cross-attention between decoder and encoded prepended tokens (#1085)

* Add blocking cross-attention between decoder and encoded prepended tokens
* Use a new dictionary-based prepared data format
xingniu committed Feb 23, 2023
1 parent 13c63be commit d912554
Showing 25 changed files with 427 additions and 100 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,21 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.32]

### Added

- Sockeye now supports blocking cross-attention between the decoder and encoded prepended tokens.
- If the source contains prepended text and a tag indicating the end of the prepended text,
  Sockeye can block the cross-attention between the decoder and the encoded prepended tokens (including the tag).
  To enable this behavior, specify `--end-of-prepending-tag` for training or data preparation,
  and `--transformer-block-prepended-cross-attention` for training.

### Changed

- Sockeye uses a new dictionary-based prepared data format that supports storing the length of prepended source tokens
  (version 7). The previous format (version 6) is still supported.
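As a rough illustration of the new layout: the key names below come from the `DATA_KEY_*` and `SHARD_*` constants added in `sockeye/constants.py` in this commit, while the tensor shapes, dtypes, and the save calls are assumptions made for the sketch, not the actual on-disk format written by `sockeye-prepare-data`.

```python
# Hypothetical sketch of a version-7 shard (key names from sockeye/constants.py;
# shapes, dtypes, and the save calls are illustrative assumptions).
import torch as pt

shard = {
    "source": pt.randint(4, 32000, (64, 50, 1), dtype=pt.int32),   # (sentences, max_len, num_factors)
    "target": pt.randint(4, 32000, (64, 48, 1), dtype=pt.int32),
    # number of prepended tokens per sentence, counting the end-of-prepending tag itself
    "prepended_source_length": pt.randint(0, 5, (64,), dtype=pt.int32),
}
# shard file names follow the SHARD_NAME pattern "shard.%05d" plus a per-key suffix
for key, tensor in shard.items():
    pt.save(tensor, "shard.%05d.%s" % (0, key))
```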

## [3.1.31]

### Fixed
7 changes: 7 additions & 0 deletions docs/training.md
@@ -185,3 +185,10 @@ This is similar to using `--restrict-lexicon` for `sockeye-translate` with the a
To use NVS simply specify `--neural-vocab-selection` to `sockeye-train`.
This will train a model with NVS that is automatically used by `sockeye-translate`.
If you want to look at translations without vocabulary selection, specify `--skip-nvs` as an argument to `sockeye-translate`.

## Prepended Source Text

If the source contains prepended text and a tag indicating the end of the prepended text,
Sockeye supports blocking the cross-attention between the decoder and the encoded prepended tokens (including the tag).
To enable this behavior, specify `--end-of-prepending-tag` for training or data preparation,
and `--transformer-block-prepended-cross-attention` for training.
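For intuition, here is a minimal sketch of how the prepended part of a source sentence is delimited and counted. The helper and the `<EOP>` tag literal are illustrative assumptions, not Sockeye's actual implementation; internally this corresponds to `data_io.get_prepended_token_length`, which operates on token ids.

```python
# Illustrative only: count the prepended tokens, including the end-of-prepending tag.
# "<EOP>" is just an example value for --end-of-prepending-tag.
from typing import List

def count_prepended_tokens(tokens: List[str], eop_tag: str = "<EOP>") -> int:
    """Number of prepended tokens including the tag; 0 if the tag is absent."""
    return tokens.index(eop_tag) + 1 if eop_tag in tokens else 0

source = "politics news <EOP> the vote was postponed".split()
print(count_prepended_tokens(source))  # -> 3 (two prepended tokens plus the tag)
```

Cross-attention from the decoder to these three positions is blocked, and they are excluded from the source length used to bound the beam-search output length.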
2 changes: 1 addition & 1 deletion sockeye/__init__.py
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.1.31'
__version__ = '3.1.32'
11 changes: 11 additions & 0 deletions sockeye/arguments.py
@@ -443,6 +443,12 @@ def add_training_data_args(params, required=False):
required=required,
type=regular_file(),
help='Target side of parallel training data.')
params.add_argument('--end-of-prepending-tag',
type=str,
default=None,
help='Tag indicating the end of prepended text. Prepended tokens before this tag (inclusive) '
'will be marked, and they will not be counted toward source length when calculating '
'maximum output length for beam search.')


def add_validation_data_params(params):
@@ -687,6 +693,11 @@ def add_model_parameters(params):
choices=C.POSITIONAL_EMBEDDING_TYPES,
default=C.FIXED_POSITIONAL_EMBEDDING,
help='The type of positional embedding. Default: %(default)s.')
model_params.add_argument('--transformer-block-prepended-cross-attention',
action='store_true',
default=False,
help='Block cross-attention between decoder and encoded prepended tokens. '
'Default: %(default)s.')
model_params.add_argument('--transformer-preprocess',
type=multiple_values(num_values=2, greater_or_equal=None, data_type=str),
default=('n', 'n'),
8 changes: 7 additions & 1 deletion sockeye/constants.py
@@ -28,6 +28,7 @@
UNK_ID = VOCAB_SYMBOLS.index(UNK_SYMBOL)
BOS_ID = VOCAB_SYMBOLS.index(BOS_SYMBOL)
EOS_ID = VOCAB_SYMBOLS.index(EOS_SYMBOL)
INVALID_ID = -1 # an example of invalid ids (i.e., negative integers)
# reserve extra space for the EOS or BOS symbol that is added to both source and target
SPACE_FOR_XOS = 1

@@ -336,13 +337,18 @@
FIXED_PARAM_STRATEGY_ENCODER_HALF_AND_SOURCE_EMBEDDINGS]

# data sharding
DATA_KEY_SOURCE = 'source'
DATA_KEY_TARGET = 'target'
DATA_KEY_PREPENDED_SOURCE_LENGTH = 'prepended_source_length'
SHARD_NAME = "shard.%05d"
SHARD_SOURCE = SHARD_NAME + ".source"
SHARD_TARGET = SHARD_NAME + ".target"
SHARD_PREPENDED_SOURCE_LENGTH = SHARD_NAME + ".prepended_source_length"
DATA_INFO = "data.info"
DATA_CONFIG = "data.config"
PREPARED_DATA_VERSION_FILE = "data.version"
PREPARED_DATA_VERSION = 6
PREPARED_DATA_VERSION = 7
PREPARED_DATA_LEGACY_VERSION = 6

# reranking metric options
RERANK_BLEU = "bleu"
185 changes: 156 additions & 29 deletions sockeye/data_io.py

Large diffs are not rendered by default.

15 changes: 6 additions & 9 deletions sockeye/decoder.py
@@ -200,26 +200,23 @@ def init_state_from_encoder(self,
[autoregressive state dummies] * num_layers.
:param encoder_outputs: Encoder outputs. Shape: (batch, source_length, encoder_dim).
:param encoder_valid_length: Valid lengths of encoder outputs. Shape: (batch,).
:param encoder_valid_length: Valid lengths of encoder outputs. Shape: (batch, 2).
:param target_embed: Target-side embedding layer output. Shape: (batch, target_length, target_embedding_dim).
:return: Initial states.
"""
source_max_len = encoder_outputs.size()[1]
# (batch * heads, 1, source_max_len)
source_mask = layers.prepare_source_length_mask(encoder_valid_length, self.config.attention_heads,
source_max_len, mask_prepended_tokens=
self.config.block_prepended_cross_attention)
if target_embed is None: # Inference: initial step = 0. Shape: (batch_size, 1)
steps = pt.zeros_like(encoder_valid_length).unsqueeze(1)
# (batch * heads, 1, source_max_len)
source_mask = layers.prepare_source_length_mask(encoder_valid_length, self.config.attention_heads,
source_max_len)
steps = pt.zeros_like(encoder_valid_length[:, :1])
# Shape: (batch, heads, 1, src_max_len)
source_mask = source_mask.view(-1, self.config.attention_heads, 1, source_max_len)
else: # Training: steps up to target length. Shape: (1, target_length)
target_length = target_embed.size()[1]
steps = pt.arange(0, target_length, device=target_embed.device).unsqueeze(0)
# (batch * heads, 1, source_max_len)
source_mask = layers.prepare_source_length_mask(encoder_valid_length, self.config.attention_heads,
source_max_len)
source_mask = source_mask.expand(-1, target_length, -1) # Shape: (batch * heads, trg_max_len, src_max_len)

# Shape: (batch, heads, trg_max_len, src_max_len)
source_mask = source_mask.view(-1, self.config.attention_heads, target_length, source_max_len)

11 changes: 8 additions & 3 deletions sockeye/generate_decoder_states.py
@@ -152,13 +152,15 @@ def init_store_file(self, initial_size: int) -> None:
def generate_states_and_store(self,
sources: List[str],
targets: List[str],
batch_size: int) -> None:
batch_size: int,
eop_id: int = C.INVALID_ID) -> None:
"""
Generate decoder states by force-decoding the sentence pairs in `sources` and `targets` with a NMT model.
:param sources: list of source segments.
:param targets: list of target segments.
:param batch_size: number of sentence pairs to decode at once.
:param eop_id: End-of-prepending tag id.
"""
assert self.state_store_file != None, \
"You should call probe_token_count first to initialize the store files."
@@ -171,7 +173,8 @@ def generate_states_and_store(self,
target_vocabs=self.target_vocabs,
batch_size=batch_size,
max_seq_len_source=self.max_seq_len_source,
max_seq_len_target=self.max_seq_len_target
max_seq_len_target=self.max_seq_len_target,
eop_id=eop_id
)

with pt.inference_mode():
@@ -254,7 +257,7 @@ def store(args: argparse.Namespace):
args.state_dtype, C.KNN_WORD_DATA_STORE_DTYPE, device)
generator.num_states = DecoderStateGenerator.probe_token_count(targets[0], max_seq_len_target)
generator.init_store_file(generator.num_states)
generator.generate_states_and_store(sources, targets, args.batch_size)
generator.generate_states_and_store(sources, targets, args.batch_size, model.eop_id)
generator.save_config()


@@ -272,6 +275,8 @@ def main():
level=args.loglevel) # pylint: disable=no-member

utils.log_basic_info(args)
if args.end_of_prepending_tag is not None:
logger.warning("The end-of-prepending tag defined in the model will be used.")

store(args)

21 changes: 15 additions & 6 deletions sockeye/inference.py
@@ -31,7 +31,7 @@
from . import utils
from . import vocab
from .beam_search import CandidateScorer, get_search_algorithm, GreedySearch, SearchResult
from .data_io import tokens2ids
from .data_io import tokens2ids, get_prepended_token_length
from .model import SockeyeModel

logger = logging.getLogger(__name__)
@@ -858,6 +858,10 @@ def num_source_factors(self) -> int:
def num_target_factors(self) -> int:
return self.models[0].num_target_factors

@property
def eop_id(self) -> int:
return self.models[0].eop_id

def translate(self, trans_inputs: List[TranslatorInput], fill_up_batches: bool = True) -> List[TranslatorOutput]:
"""
Batch-translates a list of TranslatorInputs, returns a list of TranslatorOutputs.
@@ -1001,13 +1005,14 @@ def _get_inference_input(self,
optional target prefix, and optional target prefix factors.
"""
batch_size = len(trans_inputs)
lengths = [len(inp) for inp in trans_inputs]

max_target_prefix_length = max(inp.num_target_prefix_tokens for inp in trans_inputs)
max_target_prefix_factors_length = max(inp.num_target_prefix_factors for inp in trans_inputs)
max_length = max(len(inp) for inp in trans_inputs)
# assembling source ids on cpu array (faster) and copy to Translator.device (potentially GPU) in one go below.
source_np = np.zeros((batch_size, max_length, self.num_source_factors), dtype='int32')
# total token length and prepended token length
length_np = np.zeros((batch_size, 2), dtype='int32')

target_prefix_np = np.zeros((batch_size, max_target_prefix_length), dtype='int32') \
if max_target_prefix_length > 0 else None
@@ -1019,9 +1024,13 @@ def _get_inference_input(self,
max_output_lengths = [] # type: List[int]
for j, trans_input in enumerate(trans_inputs):
num_tokens = len(trans_input) # includes eos
max_output_lengths.append(self._get_max_output_length(num_tokens))
source_np[j, :num_tokens, 0] = tokens2ids(itertools.chain(trans_input.get_source_prefix_tokens(),
trans_input.tokens), self.source_vocabs[0])
primary_source_ids = tokens2ids(itertools.chain(trans_input.get_source_prefix_tokens(),
trans_input.tokens), self.source_vocabs[0])
source_np[j, :num_tokens, 0] = primary_source_ids
length_np[j, 0] = num_tokens
length_np[j, 1] = get_prepended_token_length(primary_source_ids, self.eop_id)
# the effective source length excludes prepended tokens
max_output_lengths.append(self._get_max_output_length(length_np[j, 0] - length_np[j, 1]))
if target_prefix_np is not None and trans_input.num_target_prefix_tokens > 0:
target_prefix_np[j, :trans_input.num_target_prefix_tokens] = \
tokens2ids(trans_input.get_target_prefix_tokens(), self.vocab_targets[0])
@@ -1068,7 +1077,7 @@ def _get_inference_input(self,
"will default to not using a restrict lexicon.")

source = pt.tensor(source_np, device=self.device, dtype=pt.int32)
source_length = pt.tensor(lengths, device=self.device, dtype=pt.int32) # shape: (batch_size,)
source_length = pt.tensor(length_np, device=self.device, dtype=pt.int32) # shape: (batch_size, 2)
max_out_lengths = pt.tensor(max_output_lengths, device=self.device, dtype=pt.int32)
target_prefix = pt.tensor(target_prefix_np, device=self.device, dtype=pt.int32) \
if target_prefix_np is not None else None
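
To make the `(batch_size, 2)` length layout concrete, here is a toy recomputation of `length_np` for a single input; the token ids and the `eop_id` value are invented for illustration.

```python
# Toy illustration of the (batch_size, 2) length layout built above; all values are made up.
import numpy as np

eop_id = 9                        # hypothetical id of the end-of-prepending tag
ids = [7, 8, 9, 15, 16, 17, 3]    # two prepended tokens, the tag (9), the sentence, EOS

length_np = np.zeros((1, 2), dtype='int32')
length_np[0, 0] = len(ids)                 # total source length -> 7
length_np[0, 1] = ids.index(eop_id) + 1    # prepended length incl. the tag -> 3

# only the non-prepended part counts toward the maximum output length
effective_source_length = length_np[0, 0] - length_np[0, 1]   # -> 4
```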
26 changes: 18 additions & 8 deletions sockeye/layers.py
@@ -320,16 +320,26 @@ def forward(self,
return interleaved_matmul_encdec_valatt(key_values, probs, heads=self.heads)


def prepare_source_length_mask(lengths: pt.Tensor, heads: int, max_length: int, expand=True) -> pt.Tensor:
def prepare_source_length_mask(lengths: pt.Tensor, heads: int, max_length: int, expand: bool = True,
mask_prepended_tokens: bool = True) -> pt.Tensor:
"""
lengths: (batch_size,)
expand: Expand to the heads.
Prepare source length masks where positions of invalid tokens are marked as True.
:param lengths: Total source length and prepended source length. Shape: (batch_size, 2)
:param heads: Number of attention heads.
:param max_length: Maximum sequence length.
:param expand: Expand to the heads.
:param mask_prepended_tokens: Mask prepended tokens.
:return: Source length mask.
"""
# (batch_size, max_len)
mask = ~(pt.arange(max_length, device=lengths.device).unsqueeze(0) < lengths.reshape((-1, 1)))
mask = ~(pt.arange(max_length, device=lengths.device).unsqueeze(0) < lengths[:, :1])
if mask_prepended_tokens:
prepended_token_mask = pt.arange(max_length, device=lengths.device).unsqueeze(0) < lengths[:, 1:2]
mask |= prepended_token_mask
if expand:
# (batch_size*heads, 1, max_len)
mask = mask.unsqueeze(1).expand(-1, heads, -1).reshape((-1, max_length)).unsqueeze(1)
# (batch_size * heads, 1, max_len)
mask = mask.unsqueeze(1).expand(-1, heads, -1).reshape((-1, max_length)).unsqueeze(1)
return mask
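
A small self-contained check of the masking logic above, using toy lengths and skipping the heads expansion for readability; the values are invented, but the tensor operations mirror the function as committed.

```python
# Toy check of prepare_source_length_mask's core logic (without the heads expansion).
import torch as pt

lengths = pt.tensor([[5, 2],    # total length 5, of which 2 tokens are prepended (incl. the tag)
                     [3, 0]])   # total length 3, nothing prepended
max_length = 6

mask = ~(pt.arange(max_length).unsqueeze(0) < lengths[:, :1])   # True marks padding positions
mask |= pt.arange(max_length).unsqueeze(0) < lengths[:, 1:2]    # additionally mask prepended tokens
print(mask)
# tensor([[ True,  True, False, False, False,  True],
#         [False, False, False,  True,  True,  True]])
```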


@@ -673,7 +683,7 @@ def separate_kv(module: pt.nn.Module):
def get_positional_embeddings(length: int, depth: int) -> pt.Tensor:
utils.check_condition(depth % 2 == 0, "Positional embeddings require an even embedding size it "
"is however %d." % depth)
# (1, depth)
# (1, depth/2)
channels = pt.arange(depth // 2).unsqueeze(0)

# (length, 1)
@@ -683,7 +693,7 @@ def get_positional_embeddings(length: int, depth: int) -> pt.Tensor:
sin = pt.sin(scaled_positions)
# cosines:
cos = pt.cos(scaled_positions)
# interleave: (length, num_embed)
# stack sin and cos: (length, depth)
encodings = pt.hstack([sin, cos])
return encodings
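
The comment fixes above highlight that the sine and cosine halves are stacked side by side rather than interleaved. The sketch below reproduces only the shape bookkeeping; the standard 10000-based sinusoidal scaling is an assumption for the lines collapsed in this diff.

```python
# Shape sketch for get_positional_embeddings; the exact scaling is collapsed in the diff,
# so the standard 10000-based sinusoidal form is assumed here.
import torch as pt

length, depth = 4, 8
channels = pt.arange(depth // 2).unsqueeze(0)                   # (1, depth/2)
positions = pt.arange(length, dtype=pt.float32).unsqueeze(1)    # (length, 1)
scaled_positions = positions / pt.pow(pt.tensor(10000.0), 2 * channels / depth)  # assumed scaling

encodings = pt.hstack([pt.sin(scaled_positions), pt.cos(scaled_positions)])
print(encodings.shape)  # torch.Size([4, 8]): sin block in columns 0-3, cos block in columns 4-7
```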

14 changes: 9 additions & 5 deletions sockeye/model.py
@@ -219,7 +219,7 @@ def encode(self, inputs: pt.Tensor, valid_length: pt.Tensor) -> Tuple[pt.Tensor,
Encodes the input sequence.
:param inputs: Source input data. Shape: (batch_size, length, num_source_factors).
:param valid_length: Optional Tensor of sequence lengths within this batch. Shape: (batch_size,)
:param valid_length: Optional Tensor of sequence lengths within this batch. Shape: (batch_size, 2)
:return: Encoder outputs, encoded output lengths, attention mask
"""
if self.traced_embedding_source is None:
@@ -240,17 +240,17 @@ def encode_and_initialize(self, inputs: pt.Tensor, valid_length: pt.Tensor,
Used for inference/decoding.
:param inputs: Source input data. Shape: (batch_size, length, num_source_factors).
:param valid_length: Tensor of sequence lengths within this batch. Shape: (batch_size,)
:param valid_length: Tensor of sequence lengths within this batch. Shape: (batch_size, 2)
:param constant_length_ratio: Constant length ratio
:return: Initial states for the decoder, predicted output length of shape (batch_size,), 0 if not available.
Returns the neural vocabulary selection model prediction if enabled, None otherwise.
"""

# Encode input. Shape: (batch, length, num_hidden), (batch,)
# Encode input. Shape: (batch, length, num_hidden), (batch, 2), (batch * heads, 1, length)
source_encoded, source_encoded_lengths, att_mask = self.encode(inputs, valid_length=valid_length)

predicted_output_length = self.predict_output_length(source_encoded,
source_encoded_lengths,
source_encoded_lengths[:, 0], # total source length
constant_length_ratio)
# Decoder init states
states = self.decoder.init_state_from_encoder(source_encoded, source_encoded_lengths)
@@ -345,7 +345,7 @@ def forward(self, source, source_length, target, target_length): # pylint: disa

if self.length_ratio is not None:
# predicted_length_ratios: (batch_size,)
forward_output[C.LENRATIO_NAME] = self.length_ratio(source_encoded, source_encoded_length)
forward_output[C.LENRATIO_NAME] = self.length_ratio(source_encoded, source_encoded_length[:, 0])

if nvs_prediction is not None:
forward_output[C.NVS_PRED_NAME] = nvs_prediction
@@ -606,6 +606,10 @@ def length_ratio_std(self) -> float:
def output_layer_vocab_size(self) -> int:
return self.output_layer.vocab_size

@property
def eop_id(self) -> int:
return self.config.config_data.eop_id

def _cache_wrapper(self, class_func):
@lru_cache(maxsize=self.forward_pass_cache_size)
def cache_func(*args):
3 changes: 2 additions & 1 deletion sockeye/nvs.py
@@ -31,6 +31,7 @@ def __init__(self,
self.project_vocab = pt.nn.Linear(model_size, vocab_target_size, bias=True, dtype=dtype)

def forward(self, source_encoded: pt.Tensor, source_length: pt.Tensor, att_mask: pt.Tensor):
# TODO: att_mask might need to include prepended token masks
if self.model_type == C.NVS_TYPE_LOGIT_MAX:
# ============
# logit max:
@@ -43,7 +44,7 @@ def forward(self, source_encoded: pt.Tensor, source_length: pt.Tensor, att_mask:
# EOS based:
# ============
batch_size, max_len, _ = source_encoded.size()
source_encoded = source_encoded[pt.arange(0, batch_size, dtype=pt.long), (source_length-1).long()]
source_encoded = source_encoded[pt.arange(0, batch_size, dtype=pt.long), (source_length[:, 0] - 1).long()]
bow_pred = self.project_vocab(source_encoded)
else:
raise ValueError("Unknown neural vocabulary selection type.")
1 change: 1 addition & 0 deletions sockeye/prepare_data.py
@@ -112,6 +112,7 @@ def prepare_data(args: argparse.Namespace):
num_shards=num_shards,
output_prefix=output_folder,
bucket_scaling=bucket_scaling,
end_of_prepending_tag=args.end_of_prepending_tag,
pool=pool,
shards=shards,
keep_tmp_shard_files=keep_tmp_shard_files)
5 changes: 4 additions & 1 deletion sockeye/score.py
@@ -47,6 +47,8 @@ def score(args: argparse.Namespace):
level=args.loglevel) # pylint: disable=no-member

utils.log_basic_info(args)
if args.end_of_prepending_tag is not None:
logger.warning("The end-of-prepending tag defined in the model will be used.")

device = utils.init_device(args)
logger.info(f"Scoring device: {device}")
@@ -76,7 +78,8 @@
target_vocabs=target_vocabs,
batch_size=args.batch_size,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target)
max_seq_len_target=max_seq_len_target,
eop_id=model.eop_id)

constant_length_ratio = args.brevity_penalty_constant_length_ratio
if args.brevity_penalty_type == C.BREVITY_PENALTY_CONSTANT:
