[ONNX] First version of quantized model export: Support quantized.Linear (pytorch#69232)

Co-authored-by: David Fan <jiafa@microsoft.com>
2 people authored and BowenBao committed Jan 5, 2022
1 parent 0434a42 commit b79d6ba
Showing 13 changed files with 319 additions and 88 deletions.
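
For context, here is a minimal usage sketch of the export path this commit enables, reconstructed from the new test below. The direct call to `torch.onnx.export` with a traced module and opset 10 is an assumption drawn from the test, not part of the diff:

```python
import io
import torch

# Hypothetical usage sketch (not from this commit): export a quantized Linear.
model = torch.nn.quantized.Linear(1, 2)
x = torch.quantize_per_tensor(torch.rand(1, 1), 1.0, 0, torch.quint8)

# Per the new test, the module is traced (or scripted) first, because its
# PackedParams hold plain ints and direct tracing inside the exporter fails.
traced = torch.jit.trace(model, x)

f = io.BytesIO()
torch.onnx.export(traced, (x,), f, opset_version=10)  # opset >= 10 per the test
```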
26 changes: 23 additions & 3 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -41,6 +41,7 @@
 from torch.nn.utils.rnn import PackedSequence
 from torch.onnx import CheckerError, register_custom_op_symbolic, unregister_custom_op_symbolic
 from torch.onnx.symbolic_helper import _unimplemented
+from torch.onnx.utils import unpack_quantized_tensor
 
 
 def flatten_tuples(elem):
@@ -108,17 +109,24 @@ def inline_flatten_list(inputs, res_list):
     return res_list
 
 
+def unpack_to_numpy(value):
+    value_unpacked = []
+    for value_ in value:
+        value_unpacked.extend(unpack_quantized_tensor(value_))
+    value_final = [to_numpy(v) for v in value_unpacked]
+    return value_final
+
+
 def run_ort(ort_sess, input):
-    input = flatten_tuples(input)
-    input = to_numpy(input)
+    input = unpack_to_numpy(flatten_tuples(input))
     ort_inputs = dict((ort_sess.get_inputs()[i].name, input) for i, input in enumerate(input))
     ort_outs = ort_sess.run(None, ort_inputs)
     return inline_flatten_list(ort_outs, [])
 
 
 def ort_compare_with_pytorch(ort_outs, output, rtol, atol):
     output, _ = torch.jit._flatten(output)
-    outputs = [to_numpy(outp) for outp in output]
+    outputs = unpack_to_numpy(output)
 
     # compare onnxruntime and PyTorch results
     assert len(outputs) == len(ort_outs), "number of outputs differ"
@@ -10218,6 +10226,18 @@ def forward(self, x):
         loaded_model = onnx.load_from_string(f.getvalue())
         self.assertEqual(loaded_model.graph.output[0].type.tensor_type.shape.dim[1].dim_value, 128)
 
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_quantized_linear(self):
+        model = torch.nn.quantized.Linear(1, 2)
+        input = torch.rand(1, 1)
+        input_tensor = torch.quantize_per_tensor(input, 1, 0, torch.quint8)
+        # Currently, we need to convert the model to a ScriptModule before export.
+        # The reason is that PackedParams contains int (not tensor).
+        # Without this, export fails when the exporter calls _trace_and_get_graph_from_model().
+        # TODO: https://msdata.visualstudio.com/Vienna/_workitems/edit/1547858
+        self.run_test(torch.jit.trace(model, input_tensor), (input_tensor,))
+        self.run_test(torch.jit.script(model), (input_tensor,))
+
 def make_test(name, base, layer, bidirectional, initial_state,
               variable_length, dropout, script_test_min_opset_version,
               **extra_kwargs):
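The new `unpack_to_numpy` helper above leans on `torch.onnx.utils.unpack_quantized_tensor`, which is not shown in this diff. A rough, hypothetical sketch of the kind of unpacking involved (the name and exact return layout below are assumptions, not the real implementation):

```python
import torch

def unpack_quantized_tensor_sketch(value):
    # Hypothetical stand-in for torch.onnx.utils.unpack_quantized_tensor:
    # ONNX Runtime does not understand torch quantized tensors, so a quantized
    # value is assumed to be split into its raw integer data plus scale and
    # zero point; non-quantized values pass through untouched.
    if isinstance(value, torch.Tensor) and value.is_quantized:
        return (
            value.int_repr(),
            torch.tensor(value.q_scale()),
            torch.tensor(value.q_zero_point()),
        )
    return (value,)
```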
3 changes: 2 additions & 1 deletion torch/_C/__init__.pyi.in
@@ -323,7 +323,8 @@ def _jit_pass_onnx_remove_print(graph: Graph) -> None: ...
 def _jit_pass_onnx_preprocess_caffe2(graph: Graph) -> None: ...
 def _jit_pass_onnx_unpack_quantized_weights(
     graph: Graph,
-    paramsDict: Dict[str, IValue]
+    paramsDict: Dict[str, IValue],
+    caffe2: _bool
 ) -> Dict[str, IValue]: ...
 def _jit_pass_onnx_quantization_insert_permutes(
     graph: Graph,
7 changes: 7 additions & 0 deletions torch/csrc/jit/passes/onnx/helper.h
@@ -61,5 +61,12 @@ Node* transformToONNXConcatNode(
     bool need_new_input,
     int opset_version);
 
+class ScalarTypeHashFunction {
+ public:
+  size_t operator()(const c10::ScalarType& type) const {
+    return static_cast<size_t>(type);
+  }
+};
+
 } // namespace jit
 } // namespace torch
19 changes: 19 additions & 0 deletions torch/csrc/jit/passes/onnx/peephole.cpp
@@ -761,6 +761,24 @@ static void fuseListConstructListUnpack(Block* b) {
   }
 }
 
+// https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export
+static void eraseTupleConstruct(Block* block) {
+  size_t index = 0;
+  // TupleConstruct is generated from the symbolics in the quantized domain and consumed
+  // by other quantized operators. Any remaining TupleConstruct should be at the block outputs.
+  for (auto* output : block->outputs()) {
+    auto output_node = output->node();
+    if (output_node->kind() == prim::TupleConstruct) {
+      block->eraseOutput(index);
+      size_t input_index = 0;
+      for (auto* input : output_node->inputs()) {
+        block->insertOutput(index + (input_index++), input);
+      }
+    }
+    index++;
+  }
+}
+
 void removeMaxPoolUnusedOutput(Block* b) {
   for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
     auto n = *it;
@@ -1025,6 +1043,7 @@ void PeepholeOptimizeONNX(
   fuseListConstructListUnpack(graph->block());
   fuseLogSoftmaxNllLoss(graph->block());
   eraseListConstruct(graph->block(), opset_version);
+  eraseTupleConstruct(graph->block());
   EliminateDeadCode(
       graph->block(),
       true,
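A conceptual Python analogue of what the new `eraseTupleConstruct` pass above does to block outputs (purely illustrative; the real pass mutates the JIT graph in C++ as shown in the diff):

```python
def flatten_tuple_outputs(outputs):
    # Illustrative only: any output that is still a tuple, e.g. a quantized
    # result packed as (int_tensor, scale, zero_point), is replaced by its
    # elements so downstream ONNX consumers see flat outputs.
    flat = []
    for out in outputs:
        if isinstance(out, tuple):
            flat.extend(out)
        else:
            flat.append(out)
    return flat

# Example: [(q, 0.1, 0), mask] -> [q, 0.1, 0, mask]
```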
8 changes: 1 addition & 7 deletions torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp
@@ -2,6 +2,7 @@
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
+#include <torch/csrc/jit/passes/onnx/helper.h>
 
 namespace torch {
 namespace jit {
@@ -11,13 +12,6 @@ using namespace ::c10::onnx;
 }
 
 namespace {
-class ScalarTypeHashFunction {
- public:
-  size_t operator()(const c10::ScalarType& type) const {
-    return static_cast<size_t>(type);
-  }
-};
-
 const int ONNX_OPSET_14 = 14;
 
 static const std::unordered_map<c10::ScalarType, int, ScalarTypeHashFunction>