Sockeye 2 heafield quantize pr2 (#812)
* Quantize CLI, Docker build update, version/changelog update.
mjdenkowski committed May 22, 2020
1 parent e4553d3 commit 50393fc
Showing 15 changed files with 311 additions and 181 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,19 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [2.1.6]

### Changed

- Updated Dockerfiles optimized for CPU (intgemm int8 inference, full MKL support) and GPU (distributed training with Horovod). See [sockeye_contrib/docker](sockeye_contrib/docker).

### Added

- Official support for int8 quantization with [intgemm](https://github.com/kpu/intgemm) (see the sketch after this list):
  - This requires the "intgemm" fork of MXNet ([kpuatamazon/incubator-mxnet/intgemm](https://github.com/kpuatamazon/incubator-mxnet/tree/intgemm)). This is the version of MXNet used in the Sockeye CPU docker image (see [sockeye_contrib/docker](sockeye_contrib/docker)).
  - Use `sockeye.translate --dtype int8` to quantize a trained float32 model at runtime.
  - Use the `sockeye.quantize` CLI to annotate a float32 model with int8 scaling factors for fast runtime quantization.
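
A minimal sketch of the runtime path (the model directory name is hypothetical; the intgemm fork of MXNet must be installed, as noted above):

```python
# Minimal sketch: quantize a trained float32 model to int8 at load time.
# 'my_model' is a hypothetical directory; requires the intgemm fork of MXNet.
import sockeye.model

model, source_vocabs, target_vocab = sockeye.model.load_model('my_model', dtype='int8')
```

This is roughly the path that `sockeye.translate --dtype int8` takes when loading the model.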

## [2.1.5]

### Changed
2 changes: 1 addition & 1 deletion requirements/requirements.horovod.txt
@@ -1,2 +1,2 @@
horovod==0.18.1
horovod==0.19.1
mpi4py
1 change: 1 addition & 0 deletions setup.py
@@ -82,6 +82,7 @@ def get_requirements(filename):
'sockeye-lexicon = sockeye.lexicon:main',
'sockeye-init-embed = sockeye.init_embedding:main',
'sockeye-prepare-data = sockeye.prepare_data:main',
'sockeye-quantize = sockeye.quantize:main',
'sockeye-score = sockeye.score:main',
'sockeye-train = sockeye.train:main',
'sockeye-translate = sockeye.translate:main',
4 changes: 2 additions & 2 deletions sockeye/__init__.py
@@ -1,4 +1,4 @@
# Copyright 2017--2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2017--2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '2.1.5'
__version__ = '2.1.6'
5 changes: 3 additions & 2 deletions sockeye/beam_search.py
@@ -593,8 +593,9 @@ def forward(self,
full_to_reduced = dict((val, i) for i, val in enumerate(vocab_slice_ids))
raw_constraint_list = [[[full_to_reduced[x] for x in phr] for phr in sent] for sent in
raw_constraint_list]
#Pad to a multiple of 8.
vocab_slice_ids = np.pad(vocab_slice_ids, (0,7-((len(vocab_slice_ids)-1) % 8)), mode='constant', constant_values = self.eos_id)
# Pad to a multiple of 8.
vocab_slice_ids = np.pad(vocab_slice_ids, (0, 7 - ((len(vocab_slice_ids) - 1) % 8)),
mode='constant', constant_values = self.eos_id)
vocab_slice_ids = mx.nd.array(vocab_slice_ids, ctx=self.context, dtype='int32')

if vocab_slice_ids.shape[0] < self.beam_size + 1:
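
The new padding keeps the length of `vocab_slice_ids` a multiple of 8, presumably to satisfy the shape requirements of the int8 output layer. A small standalone check of the pad-length arithmetic, with hypothetical values:

```python
import numpy as np

# Hypothetical ids and EOS value, only to illustrate the arithmetic used above.
vocab_slice_ids = np.arange(13)                 # 13 candidate target ids
eos_id = 3                                      # padding value, as in the code above
pad = 7 - ((len(vocab_slice_ids) - 1) % 8)      # 7 - (12 % 8) = 3
padded = np.pad(vocab_slice_ids, (0, pad), mode='constant', constant_values=eos_id)
assert padded.shape[0] % 8 == 0                 # 13 + 3 = 16, a multiple of 8
```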
2 changes: 2 additions & 0 deletions sockeye/constants.py
@@ -185,6 +185,7 @@

VERSION_NAME = "version"
CONFIG_NAME = "config"
CONFIG_NAME_FLOAT32 = CONFIG_NAME + ".float32"
LOG_NAME = "log"
JSON_SUFFIX = ".json"
VOCAB_SRC_PREFIX = "vocab.src"
@@ -195,6 +196,7 @@
PARAMS_PREFIX = "params."
PARAMS_NAME = PARAMS_PREFIX + "%05d"
PARAMS_BEST_NAME = "params.best"
PARAMS_BEST_NAME_FLOAT32 = PARAMS_BEST_NAME + ".float32"
DECODE_OUT_NAME = "decode.output.%05d"
DECODE_IN_NAME = "decode.source.%d"
DECODE_REF_NAME = "decode.target"
18 changes: 9 additions & 9 deletions sockeye/layers.py
@@ -129,9 +129,9 @@ def __init__(self,
if weight is None or dtype == C.DTYPE_INT8:
if dtype == C.DTYPE_INT8:
self.scaling = self.params.get('scaling', shape=(1,), init=mx.initializer.Constant(-1.0), dtype=C.DTYPE_FP32, allow_deferred_init=False)
#This is only for inference but MXNet tries to create an
#initializer anyway, then fails because most random
#generators don't support int8 output.
# This is only for inference but MXNet tries to create an
# initializer anyway, then fails because most random
# generators don't support int8 output.
weight_initializer = 'zeros'
self.weight = self.params.get("weight",
shape=(vocab_size, hidden_size),
@@ -444,7 +444,7 @@ def __init__(self,

self.depth_att = depth_att
with self.name_scope():
self.ff_in = quantization.QuantizableDense(in_units=depth_att, units=depth_att * 3, flatten=False, use_bias=False, prefix='i2h_', dtype = dtype)
self.ff_in = quantization.QuantizableDense(in_units=depth_att, units=depth_att * 3, flatten=False, use_bias=False, prefix='i2h_', dtype=dtype)

def hybrid_forward(self, F,
inputs: mx.sym.Symbol,
@@ -526,9 +526,9 @@ def __init__(self,
super().__init__(prefix, depth_att, heads, depth_out, dropout, dtype)

with self.name_scope():
self.ff_q = quantization.QuantizableDense(in_units=depth_out, units=depth_att, flatten=False, use_bias=False, prefix='q2h_', dtype = dtype)
self.ff_k = quantization.QuantizableDense(in_units=depth_key_value, units=depth_att, flatten=False, use_bias=False, prefix='k2h_', dtype = dtype)
self.ff_v = quantization.QuantizableDense(in_units=depth_key_value, units=depth_att, flatten=False, use_bias=False, prefix='v2h_', dtype = dtype)
self.ff_q = quantization.QuantizableDense(in_units=depth_out, units=depth_att, flatten=False, use_bias=False, prefix='q2h_', dtype=dtype)
self.ff_k = quantization.QuantizableDense(in_units=depth_key_value, units=depth_att, flatten=False, use_bias=False, prefix='k2h_', dtype=dtype)
self.ff_v = quantization.QuantizableDense(in_units=depth_key_value, units=depth_att, flatten=False, use_bias=False, prefix='v2h_', dtype=dtype)

def project_and_isolate_heads(self, F, memory: mx.sym.Symbol) -> Tuple[mx.sym.Symbol, mx.sym.Symbol]:
"""
@@ -617,8 +617,8 @@ def __init__(self,
super().__init__(prefix=prefix)
self.num_hidden = num_hidden
with self.name_scope():
self.q2h = quantization.QuantizableDense(units=num_hidden, flatten=False, use_bias=True, dtype = dtype)
self.kv2h = quantization.QuantizableDense(units=num_hidden * 2, flatten=False, use_bias=True, dtype = dtype)
self.q2h = quantization.QuantizableDense(units=num_hidden, flatten=False, use_bias=True, dtype=dtype)
self.kv2h = quantization.QuantizableDense(units=num_hidden * 2, flatten=False, use_bias=True, dtype=dtype)
self.dot_att = DotAttentionCell()

def hybrid_forward(self, F,
26 changes: 16 additions & 10 deletions sockeye/model.py
@@ -50,7 +50,8 @@ class ModelConfig(Config):
:param weight_tying_type: Determines which weights get tied.
:param lhuc: LHUC (Vilar 2018) is applied at some part of the model.
:param dtype: Data type of model parameters. Default: float32.
:param intgemm_custom_lib: Path to intgemm custom operator library used for dtype is int8. Default: libintgemm.so in the same directory as this script.
:param intgemm_custom_lib: Path to the intgemm custom operator library used when dtype is int8. Default: libintgemm.so
    in the same directory as this script.
"""

def __init__(self,
@@ -120,7 +121,8 @@ def __init__(self, config: ModelConfig, inference_only: bool = False, prefix: st

# encoder & decoder first (to know the decoder depth)
self.encoder = encoder.get_encoder(self.config.config_encoder, prefix=self.prefix, dtype=config.dtype)
self.decoder = decoder.get_decoder(self.config.config_decoder, inference_only=inference_only, prefix=self.prefix, dtype=config.dtype)
self.decoder = decoder.get_decoder(self.config.config_decoder, inference_only=inference_only,
prefix=self.prefix, dtype=config.dtype)

self.output_layer = layers.OutputLayer(hidden_size=self.decoder.get_num_hidden(),
vocab_size=self.config.vocab_target_size,
@@ -452,7 +454,7 @@ def load_model(model_folder: str,
checkpoint: Optional[int] = None,
hybridize: bool = True,
inference_only: bool = False,
for_disk_saving: str = None,
for_disk_saving: Optional[str] = None,
allow_missing: bool = False,
set_grad_req_null: bool = True) -> Tuple[SockeyeModel, List[vocab.Vocab], vocab.Vocab]:
"""
@@ -490,15 +492,19 @@
else:
params_fname = os.path.join(model_folder, C.PARAMS_NAME % checkpoint)

if (dtype == C.DTYPE_INT8 or model_config.dtype == C.DTYPE_INT8 or for_disk_saving is not None) and "intgemm_fully_connected" not in dir(mx.nd.contrib):
#We're going to use int8 but it's not compiled into mxnet.
if (dtype == C.DTYPE_INT8 or
model_config.dtype == C.DTYPE_INT8 or
for_disk_saving is not None) and "intgemm_fully_connected" not in dir(mx.nd.contrib):
# We're going to use int8 but it's not compiled into mxnet.
path = os.path.abspath(model_config.intgemm_custom_lib)
try:
mx.library.load(path)
except(mx.base.MXNetError):
raise NotImplementedError("8-bit int inference requested but intgemm was not compiled into MXNet and a custom operator library was not found in `" + path + "`. Compile the custom operator then set the path using intgemm_custom_lib in the config file.")
except mx.base.MXNetError:
raise NotImplementedError("8-bit int inference requested but intgemm was not compiled into MXNet and a "
"custom operator library was not found in `%s`. Compile the custom "
"operator then set the path using intgemm_custom_lib in the config file." % path)

#Are we converting the model to 8-bit?
# Are we converting the model to 8-bit?
quantizing = model_config.dtype != C.DTYPE_INT8 and (dtype == C.DTYPE_INT8 or for_disk_saving is not None)
if quantizing:
model_config.dtype = C.DTYPE_INT8 # Ensure the scaling factor parameters are created.
Expand Down Expand Up @@ -535,12 +541,12 @@ def load_model(model_folder: str,
ignore_extra=True, #Scaling factors may be present in float32 models.
cast_dtype=cast_dtype,
dtype_source=dtype_source)

params = model.collect_params()
if set_grad_req_null:
for param in params.values():
param.grad_req = 'null'

if for_disk_saving is not None:
#Saving scaling factors and possibly int8 values to disk.
if not quantizing:
70 changes: 36 additions & 34 deletions sockeye/quantization.py
@@ -1,4 +1,4 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
@@ -11,14 +11,17 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import mxnet as mx
import logging
import math
from . import constants as C

import mxnet as mx
from mxnet.gluon.nn.activations import Activation
import logging

from . import constants as C

logger = logging.getLogger(__name__)


# Modified from the source to mxnet.gluon.nn.basic_layers.Dense which is:
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
@@ -130,7 +133,7 @@ def cast(self, dtype):
#No casting an already quantized matrix.
logger.warning("Ignoring casting on int8 matrix")

def hybrid_forward(self, F, x, weight, scaling = None, bias=None):
def hybrid_forward(self, F, x, weight, scaling=None, bias=None):
if self._dtype == C.DTYPE_INT8:
if bias is not None:
act = F.contrib.intgemm_fully_connected(x, weight, scaling, bias, no_bias=False, num_hidden=self._units,
@@ -155,28 +158,29 @@ def __repr__(self):
layout='{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0]))


#Minimize mean squared error of quantizing a tensor, returning the top value
#(i.e. the one that quantizes to 127). Scaling = 127.0 / return value.
def optimize_quantization_mse(tensor, rounds = 10):
#This is a convex optimization problem. EM works but makes slow steps.
#Instead of EM, use binary search in the direction minimization suggests.
def optimize_quantization_mse(tensor, rounds=10):
"""
Minimize mean squared error of quantizing a tensor, returning the top value
(i.e. the one that quantizes to 127). Scaling = 127.0 / return value.
This is a convex optimization problem. EM works but makes slow steps.
Instead of EM, use binary search in the direction minimization suggests.
"""
best_mse = math.inf
best_top = None
maxabs = mx.nd.contrib.intgemm_maxabsolute(tensor)
# For converting python numbers to MXNet NDArray
one = mx.nd.ones(shape=(1,))
low = 0.0
high = maxabs
for i in range(rounds):
for _ in range(rounds):
value = (low + high) / 2.0
quant = mx.nd.contrib.intgemm_prepare_data(tensor, value)
quant_float = mx.nd.cast(quant, dtype=C.DTYPE_FP32)
mse = (quant_float * (value / 127.0) - tensor).norm().asscalar() / math.sqrt(float(tensor.size))
if mse < best_mse:
best_mse = mse
best_top = value
#This optimizes scaling subject to cluster assignment.
#It can be used for EM but the step is really slow, so use it for direction.
# This optimizes scaling subject to cluster assignment.
# It can be used for EM but the step is really slow, so use it for direction.
scale = mx.nd.sum(quant_float * quant_float) / mx.nd.sum(quant_float * tensor)
top = 127.0 / scale.asscalar()
if top < value:
@@ -185,18 +189,19 @@ def optimize_quantization_mse(tensor, rounds = 10):
low = value
return best_top


def extract_quant_max(tensor_param: mx.gluon.parameter.Parameter, scaling_param: mx.gluon.parameter.Parameter) -> float:
"""
Extract or tune the scaling factor for a parameter.
"""
scaling = scaling_param.data()
if scaling.asscalar() < 0:
#Bogus auto initialized scaling factor.
b_max = optimize_quantization_mse(tensor_param.data())
scaling_param.set_data(b_max / 127.0)
else:
b_max = scaling * 127.0
return b_max
"""
Extract or tune the scaling factor for a parameter.
"""
scaling = scaling_param.data()
if scaling.asscalar() < 0:
# Bogus auto initialized scaling factor.
b_max = optimize_quantization_mse(tensor_param.data())
scaling_param.set_data(b_max / 127.0)
else:
b_max = scaling * 127.0
return b_max


def convert_weights_disk_format(params: mx.gluon.parameter.ParameterDict, dtype_store: str):
@@ -221,6 +226,7 @@ def convert_weights_disk_format(params: mx.gluon.parameter.ParameterDict, dtype_
param.set_data(quantized)
param.dtype = C.DTYPE_INT8


def convert_weights_cpu_dependent(params: mx.gluon.parameter.ParameterDict):
"""
Convert weights from disk format to intgemm's CPU-dependent format for
@@ -235,15 +241,11 @@ def convert_weights_cpu_dependent(params: mx.gluon.parameter.ParameterDict):
scaling_name = name[0:-6] + "scaling"
if scaling_name in params:
if param.dtype == C.DTYPE_INT8:
#Already fully quantized, just rearrange.
weight = mx.nd.contrib.intgemm_prepare_weight(
param.data(), already_quantized = True)
# Already fully quantized, just rearrange.
weight = mx.nd.contrib.intgemm_prepare_weight(param.data(), already_quantized = True)
else:
#Use offline scaling factor if available.
# Use offline scaling factor if available.
b_max = extract_quant_max(param, params[scaling_name])
weight = mx.nd.contrib.intgemm_prepare_weight(
param.data(),
b_max)
weight = mx.nd.contrib.intgemm_prepare_weight(param.data(), b_max)
param.set_data(weight)
param.dtype = C.DTYPE_INT8

59 changes: 59 additions & 0 deletions sockeye/quantize.py
@@ -0,0 +1,59 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
# is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import argparse
import logging
import os

import sockeye.constants as C
from sockeye.log import setup_main_logger, log_sockeye_version
import sockeye.model
from sockeye.utils import check_condition

logger = logging.getLogger(__name__)


def annotate_model_params(model_dir: str):
log_sockeye_version(logger)

params_best = os.path.join(model_dir, C.PARAMS_BEST_NAME)
params_best_float32 = os.path.join(model_dir, C.PARAMS_BEST_NAME_FLOAT32)
config = os.path.join(model_dir, C.CONFIG_NAME)
config_float32 = os.path.join(model_dir, C.CONFIG_NAME_FLOAT32)

for fname in params_best_float32, config_float32:
check_condition(not os.path.exists(fname),
'File "%s" exists, indicating this model has already been quantized.' % fname)

# Load model and compute scaling factors
model = sockeye.model.load_model(model_dir, for_disk_saving='float32', dtype='int8')
# Move original params and config files
os.rename(params_best, params_best_float32)
os.rename(config, config_float32)
# Write new params and config files with annotated scaling factors
model[0].save_parameters(params_best)
model[0].save_config(model_dir)


def main():
setup_main_logger(console=True, file_logging=False)
params = argparse.ArgumentParser(
description='Annotate trained model with scaling factors for fast loading/quantization for int8 inference.')
params.add_argument('--model', '-m', required=True, help='Trained Sockeye model directory.')
args = params.parse_args()

annotate_model_params(args.model)


if __name__ == '__main__':
main()
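
A minimal usage sketch for the new module (the model directory name is hypothetical; this is equivalent to the `sockeye-quantize` console script added in setup.py):

```python
# Minimal sketch: annotate a trained model in place; 'my_model' is hypothetical.
from sockeye.quantize import annotate_model_params

annotate_model_params('my_model')
# Expected contents afterwards, based on the code above:
#   params.best          - parameters annotated with int8 scaling factors
#   params.best.float32  - backup of the original float32 parameters
#   config               - rewritten config
#   config.float32       - backup of the original config
```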