Skip to content

Commit

Permalink
Refactor inference optimizations (#1094)
Browse files Browse the repository at this point in the history
  • Loading branch information
xingniu committed Aug 25, 2023
1 parent 8753d95 commit a59e7a2
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 68 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.36]

### Changed

- Make the `inference_only` mode switchable.
- Simplify inference optimizations by
(1) using `eval()` to disable dropout instead of explicitly setting dropout modules to `None`;
(2) always using the default value `inplace=False` for activation modules.

## [3.1.35]

### Fixed
Expand Down
4 changes: 2 additions & 2 deletions sockeye/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand All @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.1.35'
__version__ = '3.1.36'
14 changes: 1 addition & 13 deletions sockeye/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand Down Expand Up @@ -83,15 +83,3 @@ def copy(self, **kwargs):
for name, value in kwargs.items():
object.__setattr__(copy_obj, name, value)
return copy_obj

def disable_dropout(self):
    """
    Zero out the dropout values stored on this config and on any nested configs.

    Walks the instance attributes: attributes that are themselves a Config are handled
    recursively, and every float-valued attribute whose name contains 'dropout' is set to 0.0.
    """
    for name, value in vars(self).items():
        if isinstance(value, Config):
            # Nested configs take care of their own dropout attributes.
            value.disable_dropout()
            continue
        if isinstance(value, float) and 'dropout' in name:
            logger.debug("Setting %s to 0.0", name)
            setattr(self, name, 0.0)
24 changes: 17 additions & 7 deletions sockeye/decoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand Down Expand Up @@ -93,6 +93,10 @@ def get_decoder(cls,
def __init__(self):
super().__init__()

@abstractmethod
def set_inference_only(self, inference_only: bool):
    """
    Set the inference_only mode. Concrete subclasses must override this method.

    :param inference_only: The new value of the inference_only flag.
    :raises NotImplementedError: Always, when the abstract version is invoked.
    """
    raise NotImplementedError

@abstractmethod
def state_structure(self) -> str:
raise NotImplementedError()
Expand Down Expand Up @@ -147,7 +151,6 @@ def __init__(self,
Decoder.__init__(self)
pt.nn.Module.__init__(self)
self.config = config
self.inference_only = inference_only
self.pos_embedding = layers.PositionalEmbeddings(weight_type=self.config.positional_embedding_type,
num_embed=self.config.model_size,
max_seq_len=self.config.max_seq_len_target,
Expand All @@ -158,7 +161,7 @@ def __init__(self,

self.layers = pt.nn.ModuleList( # using ModuleList because we have additional inputs
transformer.TransformerDecoderBlock(config,
inference_only=self.inference_only,
inference_only=inference_only,
dtype=dtype,
clamp_to_dtype=clamp_to_dtype)
for _ in range(config.num_layers))
Expand All @@ -168,8 +171,16 @@ def __init__(self,
num_hidden=self.config.model_size,
dtype=dtype,
clamp_to_dtype=clamp_to_dtype)
if self.config.dropout_prepost > 0.0:
self.dropout = pt.nn.Dropout(p=self.config.dropout_prepost, inplace=inference_only)
self.dropout = pt.nn.Dropout(p=self.config.dropout_prepost)
self.set_inference_only(inference_only)

def set_inference_only(self, inference_only: bool):
    """
    Store the inference_only flag on this object and pass it on to every layer.

    :param inference_only: Value assigned to ``self.inference_only`` and forwarded to the
                           ``set_inference_only`` method of each entry in ``self.layers``.
    """
    self.inference_only = inference_only
    for decoder_block in self.layers:
        decoder_block.set_inference_only(inference_only)

def state_structure(self) -> str:
"""
Expand Down Expand Up @@ -279,8 +290,7 @@ def forward(self, step_input: pt.Tensor, states: List[pt.Tensor]) -> Tuple[pt.Te
# (length, batch_size, model_size)
target = target.transpose(1, 0)

if self.config.dropout_prepost > 0.0:
target = self.dropout(target)
target = self.dropout(target)

new_autoregr_states = [] # type: List[pt.Tensor]
for layer, layer_autoregr_state, layer_enc_att_kv in zip(self.layers, autoregr_states, enc_att_kv):
Expand Down
6 changes: 3 additions & 3 deletions sockeye/encoder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(self,
self.factor_embeds.append(factor_embed)
self.factor_combinations.append(fc.combine)

self.dropout = pt.nn.Dropout(p=self.config.dropout) if self.config.dropout > 0.0 else None
self.dropout = pt.nn.Dropout(p=self.config.dropout)

def forward(self, data: pt.Tensor) -> pt.Tensor:
primary_data = data[:, :, 0]
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(self,
pt.nn.Module.__init__(self)
self.config = config

self.dropout = pt.nn.Dropout(p=config.dropout_prepost) if config.dropout_prepost > 0.0 else None
self.dropout = pt.nn.Dropout(p=config.dropout_prepost)

self.pos_embedding = layers.PositionalEmbeddings(weight_type=self.config.positional_embedding_type,
num_embed=self.config.model_size,
Expand Down
17 changes: 14 additions & 3 deletions sockeye/inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand Down Expand Up @@ -787,6 +787,8 @@ def __init__(self,
if strip_unknown_words:
self.strip_ids.add(self.unk_id)
self.models = models
for model in self.models:
model.eval()

# after models are loaded we ensured that they agree on max_input_length, max_output_length and batch size
# set a common max_output length for all models.
Expand Down Expand Up @@ -943,8 +945,7 @@ def translate(self, trans_inputs: List[TranslatorInput], fill_up_batches: bool =
batch = batch + [batch[0]] * rest

translator_inputs = [indexed_translator_input.translator_input for indexed_translator_input in batch]
with pt.inference_mode():
batch_translations = self._translate_np(*self._get_inference_input(translator_inputs))
batch_translations = self._translate_batch(translator_inputs)

# truncate to remove filler translations
if fill_up_batches and rest > 0:
Expand Down Expand Up @@ -988,6 +989,16 @@ def translate(self, trans_inputs: List[TranslatorInput], fill_up_batches: bool =

return results

def _translate_batch(self, translator_inputs: List[TranslatorInput]) -> List[Translation]:
    """
    Translate a single batch of inputs inside a ``pt.inference_mode()`` context.

    :param translator_inputs: List of TranslatorInputs making up the batch.
    :return: List of Translation, as returned by ``_translate_np``.
    """
    with pt.inference_mode():
        # Build the inference inputs first, then unpack them as positional arguments.
        inference_input = self._get_inference_input(translator_inputs)
        return self._translate_np(*inference_input)

def _get_inference_input(self,
trans_inputs: List[TranslatorInput]) -> Tuple[pt.Tensor,
pt.Tensor,
Expand Down
35 changes: 27 additions & 8 deletions sockeye/layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand Down Expand Up @@ -26,12 +26,12 @@
logger = logging.getLogger(__name__)


def get_activation(act_type: str, inplace: bool = False) -> pt.nn.Module:
def get_activation(act_type: str) -> pt.nn.Module:
if act_type == C.SWISH1:
return pt.nn.SiLU(inplace=inplace)
return pt.nn.SiLU()
if act_type == C.GELU:
return pt.nn.GELU()
return pt.nn.ReLU(inplace=inplace)
return pt.nn.ReLU()


class LHUC(pt.nn.Module):
Expand Down Expand Up @@ -287,7 +287,7 @@ class DotAttentionCell(pt.nn.Module):

def __init__(self, dropout: float = 0.0, heads: int = 1) -> None:
super().__init__()
self.dropout = pt.nn.Dropout(p=dropout) if dropout > 0.0 else None
self.dropout = pt.nn.Dropout(p=dropout)
self.heads = heads

def forward(self,
Expand Down Expand Up @@ -420,6 +420,13 @@ def get_state_shape(self, batch_size) -> Tuple:
"""
raise NotImplementedError

@abstractmethod
def set_inference_only(self, inference_only: bool):
    """
    Set the inference_only mode. Concrete subclasses must override this method.

    :param inference_only: The new value of the inference_only flag.
    :raises NotImplementedError: Always, when the abstract version is invoked.
    """
    raise NotImplementedError()

@abstractmethod
def forward(self, inputs: pt.Tensor, previous_states: pt.Tensor, *args) -> Tuple:
"""
Expand Down Expand Up @@ -461,6 +468,12 @@ def __init__(self,
# Interleaved format is used for inference, non-interleaved format is used for fused MHA in training.
self.kv_interleaved = False

def set_inference_only(self, inference_only: bool):
    """
    Set inference_only. Not needed for MultiHeadSelfAttention, so this is deliberately
    left unimplemented: calling it raises instead of silently doing nothing.

    NOTE(review): callers that toggle inference_only across layers presumably have to skip
    this layer type; confirm against the call sites.

    :param inference_only: Ignored.
    :raises NotImplementedError: Always.
    """
    raise NotImplementedError()

def separate_kv(self):
""" write kv input projection parameters in non-interleaved format (compatible with F.multi_head_attention) """
assert self.kv_interleaved
Expand Down Expand Up @@ -799,11 +812,9 @@ def __init__(self,
clamp_to_dtype: bool = False,) -> None:
super().__init__()
self.model_size = model_size
self.inference_only = inference_only
self.clamp_to_dtype = clamp_to_dtype

self.cell_state_transform = self._inference_cell_state_transform \
if inference_only else self._training_cell_state_transform
self.set_inference_only(inference_only)

self.forget_gate = pt.nn.Linear(in_features=model_size, out_features=model_size, bias=True, dtype=dtype)
self.forget_gate_act = pt.nn.Sigmoid()
Expand All @@ -812,6 +823,14 @@ def __init__(self,

self.relu = pt.nn.ReLU(inplace=False) # inplace=False because we need to non-activated data as well

def set_inference_only(self, inference_only: bool):
    """
    Store the inference_only flag and select the matching cell state transform.

    :param inference_only: If true, ``cell_state_transform`` is bound to
                           ``_inference_cell_state_transform``; otherwise it is bound to
                           ``_training_cell_state_transform``.
    """
    self.inference_only = inference_only
    if inference_only:
        self.cell_state_transform = self._inference_cell_state_transform
    else:
        self.cell_state_transform = self._training_cell_state_transform

@property
def num_state_tensors(self) -> int:
""" Number of state tensors returned by the layer """
Expand Down
38 changes: 17 additions & 21 deletions sockeye/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2017--2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
Expand Down Expand Up @@ -111,7 +111,6 @@ def __init__(self,
super().__init__()
self.config = copy.deepcopy(config)
self.dtype = utils.get_torch_dtype(config.dtype)
self.inference_only = inference_only
self.clamp_to_dtype = clamp_to_dtype
logger.info("%s", self.config)
self.train_decoder_only = train_decoder_only
Expand Down Expand Up @@ -144,10 +143,10 @@ def __init__(self,
vocab_size=self.config.vocab_target_size,
weight=output_weight,
dtype=self.dtype)
if self.inference_only:
# Running this layer scripted with a newly initialized model can
# cause an overflow error.
self.output_layer = pt.jit.script(self.output_layer)
self.output_layer_module_cached = self.output_layer
# Running this layer scripted with a newly initialized model can cause an overflow error.
self.output_layer_script_cached = pt.jit.script(self.output_layer_module_cached)
self.set_inference_only(inference_only)

self.factor_output_layers = pt.nn.ModuleList()
# Optional target factor output layers
Expand Down Expand Up @@ -189,6 +188,14 @@ def __init__(self,

self.knn : Optional[layers.KNN] = None

def set_inference_only(self, inference_only: bool):
    """
    Turn inference_only optimization on or off.

    Stores the flag, points ``output_layer`` at the cached scripted output layer when the
    flag is set (and at the cached plain module otherwise), and forwards the flag to the
    decoder.

    :param inference_only: Whether the inference_only optimization should be active.
    """
    self.inference_only = inference_only
    if self.inference_only:
        self.output_layer = self.output_layer_script_cached
    else:
        self.output_layer = self.output_layer_module_cached
    self.decoder.set_inference_only(self.inference_only)

def cast(self, dtype: Union[pt.dtype, str]):
dtype = utils.get_torch_dtype(dtype)
Expand Down Expand Up @@ -417,7 +424,8 @@ def save_parameters(self, fname: str):
# filter their names from the state dictionary to avoid saving redundant
# copies of their parameters. Copies can also cause errors at loadtime
# if the traced modules do not yet exist.
filtered_state_dict = {name: param for (name, param) in self.state_dict().items() if 'traced' not in name}
filtered_state_dict = {name: param for (name, param) in self.state_dict().items()
if 'traced' not in name and 'cached' not in name}
pt.save(filtered_state_dict, fname)
self.apply(layers.separate_kv)
logging.info('Saved params/state_dict to "%s"', fname)
Expand Down Expand Up @@ -445,12 +453,12 @@ def load_parameters(self,
missing, unexpected = self.load_state_dict(state_dict, strict=False)
# Earlier versions of Sockeye may have saved parameters for traced
# modules. These parameters can be safely ignored.
unexpected = [key for key in unexpected if 'traced' not in key]
unexpected = [key for key in unexpected if 'traced' not in key and 'cached' not in key]
# We also ignore cases where traced modules exist and appear to be
# missing parameters. These modules actually use the same parameters as
# their original non-traced versions so there are no separate parameters
# to load.
missing = [key for key in missing if 'traced' not in key]
missing = [key for key in missing if 'traced' not in key and 'cached' not in key]
if not allow_missing:
utils.check_condition(not missing, f"missing keys: {missing}")
if not ignore_extra:
Expand Down Expand Up @@ -706,7 +714,6 @@ def load_model(model_folder: str,
inference_only: bool = False,
train_decoder_only: bool = False,
allow_missing: bool = False,
set_grad_req_null: bool = True,
forward_pass_cache_size: int = 0,
knn_index: Optional[str] = None) -> Tuple[SockeyeModel, List[vocab.Vocab], List[vocab.Vocab]]:
"""
Expand All @@ -721,7 +728,6 @@ def load_model(model_folder: str,
:param train_decoder_only: Training will only update the decoder. Disable
autograd for encoder and embeddings to save memory.
:param allow_missing: Allow missing parameters in the loaded model.
:param set_grad_req_null: Set grad_req to null for model parameters.
:param forward_pass_cache_size: If > 0, cache encoder and embedding calculations of forward pass.
:param knn_index: Optional path to a folder containing a KNN model index.
:return: List of models, source vocabularies, target vocabularies.
Expand All @@ -733,10 +739,6 @@ def load_model(model_folder: str,
utils.check_version(model_version)
model_config = SockeyeModel.load_config(os.path.join(model_folder, C.CONFIG_NAME))

if inference_only:
logger.info("Disabling dropout layers for performance reasons")
model_config.disable_dropout()

if checkpoint is None:
params_fname = os.path.join(model_folder, C.PARAMS_BEST_NAME)
else:
Expand All @@ -755,9 +757,6 @@ def load_model(model_folder: str,

model.to(device)

if set_grad_req_null:
model.eval()

if dtype is None:
logger.info("Model dtype: %s" % model.dtype)
else:
Expand All @@ -783,7 +782,6 @@ def load_models(device: pt.device,
inference_only: bool = False,
train_decoder_only: bool = False,
allow_missing: bool = False,
set_grad_req_null: bool = True,
forward_pass_cache_size: int = 0,
knn_index: Optional[str] = None) -> Tuple[List[SockeyeModel],
List[vocab.Vocab],
Expand All @@ -800,7 +798,6 @@ def load_models(device: pt.device,
:param train_decoder_only: Training will only update the decoder. Disable
autograd for encoder and embeddings to save memory.
:param allow_missing: Allow missing parameters in the loaded models.
:param set_grad_req_null: Set grad_req to null for model parameters.
:param forward_pass_cache_size: If > 0, cache encoder and embedding calculations of forward pass.
:param knn_index: Optional path to a folder containing a KNN model index.
:return: List of models, source vocabulary, target vocabulary, source factor vocabularies.
Expand All @@ -825,7 +822,6 @@ def load_models(device: pt.device,
inference_only=inference_only,
train_decoder_only=train_decoder_only,
allow_missing=allow_missing,
set_grad_req_null=set_grad_req_null,
forward_pass_cache_size=forward_pass_cache_size,
knn_index=knn_index)
models.append(model)
Expand Down

0 comments on commit a59e7a2

Please sign in to comment.