From 23687dc550fa45d90e359771e1ba59d39f7d2fc0 Mon Sep 17 00:00:00 2001 From: Fabian Boemer Date: Fri, 15 Mar 2019 13:45:06 -0700 Subject: [PATCH] Fboemer/ngtf v0.12.0 rc0 (#144) * Update to ng-tf v0.12.0-rc0 * ng to v0.15.1-rc.2 --- README.md | 4 +- cmake/ngraph-tf.cmake | 2 +- src/CMakeLists.txt | 1 + src/he_backend.cpp | 912 +--------------------------------- src/he_backend.hpp | 36 +- src/he_executable.cpp | 932 +++++++++++++++++++++++++++++++++++ src/he_executable.hpp | 64 +++ src/he_plain_tensor.cpp | 14 +- src/he_plain_tensor.hpp | 6 +- test/test_add.in.cpp | 15 +- test/test_avg_pool.in.cpp | 18 +- test/test_basics.in.cpp | 24 +- test/test_broadcast.in.cpp | 27 +- test/test_concat.in.cpp | 15 +- test/test_constant.in.cpp | 6 +- test/test_convolution.in.cpp | 21 +- test/test_cryptonets.cpp | 3 +- test/test_dot.in.cpp | 15 +- test/test_layers.in.cpp | 36 +- test/test_multiply.in.cpp | 12 +- test/test_negate.in.cpp | 3 +- test/test_pad.in.cpp | 54 +- test/test_reshape.in.cpp | 33 +- test/test_reverse.in.cpp | 45 +- test/test_slice.in.cpp | 21 +- test/test_subtract.in.cpp | 9 +- test/test_sum.in.cpp | 27 +- 27 files changed, 1277 insertions(+), 1078 deletions(-) create mode 100644 src/he_executable.cpp create mode 100644 src/he_executable.hpp diff --git a/README.md b/README.md index 078b66e4..5bd79e95 100644 --- a/README.md +++ b/README.md @@ -26,8 +26,8 @@ The [examples](https://github.com/NervanaSystems/he-transformer/tree/master/exam - virtualenv v16.1.0 - bazel v0.16.0 #### The following dependencies are built automatically -- [nGraph](https://github.com/NervanaSystems/ngraph) - v0.14.0 -- [nGraph-tf](https://github.com/NervanaSystems/ngraph-tf) - v0.11.0 +- [nGraph](https://github.com/NervanaSystems/ngraph) - v0.15.1-rc.2 +- [nGraph-tf](https://github.com/NervanaSystems/ngraph-tf) - v0.12-rc0 - [SEAL](https://github.com/Microsoft/SEAL) - version 3.2 ### To install bazel diff --git a/cmake/ngraph-tf.cmake b/cmake/ngraph-tf.cmake index 3c454cb1..c3c1f4da 100644 --- a/cmake/ngraph-tf.cmake +++ b/cmake/ngraph-tf.cmake @@ -20,7 +20,7 @@ set(EXTERNAL_NGRAPH_INSTALL_DIR ${EXTERNAL_INSTALL_DIR}) set(NGRAPH_TF_CMAKE_PREFIX ext_ngraph_tf) SET(NGRAPH_TF_REPO_URL https://github.com/NervanaSystems/ngraph-tf.git) -SET(NGRAPH_TF_GIT_LABEL v0.11.0) +SET(NGRAPH_TF_GIT_LABEL v0.12.0-rc0) SET(NGRAPH_TF_SRC_DIR ${CMAKE_BINARY_DIR}/${NGRAPH_TF_CMAKE_PREFIX}/src/${NGRAPH_TF_CMAKE_PREFIX}) SET(NGRAPH_TF_BUILD_DIR ${NGRAPH_TF_SRC_DIR}/build) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f28eaf7b..1e17cf1a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -42,6 +42,7 @@ set (HE_SRC # main he_backend.cpp + he_executable.cpp he_cipher_tensor.cpp he_plain_tensor.cpp he_tensor.cpp diff --git a/src/he_backend.cpp b/src/he_backend.cpp index 7161876a..a14becf2 100644 --- a/src/he_backend.cpp +++ b/src/he_backend.cpp @@ -18,47 +18,15 @@ #include "he_backend.hpp" #include "he_cipher_tensor.hpp" +#include "he_executable.hpp" #include "he_plain_tensor.hpp" #include "he_tensor.hpp" -#include "kernel/add.hpp" -#include "kernel/avg_pool.hpp" -#include "kernel/broadcast.hpp" -#include "kernel/concat.hpp" -#include "kernel/constant.hpp" -#include "kernel/convolution.hpp" -#include "kernel/dot.hpp" -#include "kernel/multiply.hpp" -#include "kernel/negate.hpp" -#include "kernel/pad.hpp" -#include "kernel/reshape.hpp" -#include "kernel/result.hpp" -#include "kernel/reverse.hpp" -#include "kernel/slice.hpp" -#include "kernel/subtract.hpp" -#include "kernel/sum.hpp" #include 
"ngraph/descriptor/layout/dense_tensor_layout.hpp" #include "ngraph/function.hpp" -#include "ngraph/op/avg_pool.hpp" -#include "ngraph/op/broadcast.hpp" -#include "ngraph/op/concat.hpp" -#include "ngraph/op/constant.hpp" -#include "ngraph/op/convolution.hpp" -#include "ngraph/op/dot.hpp" -#include "ngraph/op/pad.hpp" -#include "ngraph/op/reshape.hpp" -#include "ngraph/op/result.hpp" -#include "ngraph/op/reverse.hpp" -#include "ngraph/op/slice.hpp" -#include "ngraph/op/sum.hpp" -#include "ngraph/pass/assign_layout.hpp" -#include "ngraph/pass/manager.hpp" -#include "ngraph/pass/visualize_tree.hpp" using namespace ngraph; using namespace std; -using descriptor::layout::DenseTensorLayout; - shared_ptr runtime::he::HEBackend::create_valued_plaintext( float value, const element::Type& element_type) const { @@ -142,877 +110,9 @@ shared_ptr runtime::he::HEBackend::create_valued_plain_tensor( return tensor; } -runtime::Handle runtime::he::HEBackend::compile(shared_ptr function) { - FunctionInstance& instance = m_function_map[function]; - if (!instance.m_is_compiled) { - instance.m_is_compiled = true; - pass::Manager pass_manager; - pass_manager.register_pass(); - pass_manager.register_pass>(); - pass_manager.register_pass(); - pass_manager.run_passes(function); - - for (const shared_ptr& node : function->get_ordered_ops()) { - instance.m_wrapped_nodes.emplace_back(node); - } - } - return function; -} - -void runtime::he::HEBackend::validate_he_call( - shared_ptr function, - const vector>& outputs, - const vector>& inputs) { - const ParameterVector& input_parameters = function->get_parameters(); - if (input_parameters.size() != inputs.size()) { - stringstream ss; - ss << "Call input count " << inputs.size() - << " does not match Function's Parameter count " - << input_parameters.size(); - throw runtime_error(ss.str()); - } - if (function->get_output_size() != outputs.size()) { - stringstream ss; - ss << "Call output count " << outputs.size() - << " does not match Function's Result count " - << function->get_output_size(); - throw runtime_error(ss.str()); - } - - for (size_t i = 0; i < input_parameters.size(); i++) { - if (inputs[i]->get_element_type() != - input_parameters[i]->get_element_type()) { - stringstream ss; - ss << "Input " << i << " type '" << inputs[i]->get_element_type() - << "' does not match Parameter type '" - << input_parameters[i]->get_element_type() << "'"; - throw runtime_error(ss.str()); - } - if (inputs[i]->get_expanded_shape() != input_parameters[i]->get_shape()) { - stringstream ss; - ss << "Input " << i << " shape {" << join(inputs[i]->get_expanded_shape()) - << "} does not match Parameter shape {" - << join(input_parameters[i]->get_shape()) << "}"; - throw runtime_error(ss.str()); - } - } - - for (size_t i = 0; i < function->get_output_size(); i++) { - if (outputs[i]->get_element_type() != - function->get_output_element_type(i)) { - stringstream ss; - ss << "Output " << i << " type '" << outputs[i]->get_element_type() - << "' does not match Result type '" - << function->get_output_element_type(i) << "'"; - throw runtime_error(ss.str()); - } - if (function->get_output_shape(i) != outputs[i]->get_expanded_shape()) { - stringstream ss; - ss << "Output " << i << " shape {" - << join(outputs[i]->get_expanded_shape()) - << "} does not match Result shape {" - << join(function->get_output_shape(i)) << "}"; - throw runtime_error(ss.str()); - } - } +std::shared_ptr runtime::he::HEBackend::compile( + shared_ptr function, bool enable_performance_collection) { + return 
make_shared(function, enable_performance_collection, + this, m_encrypt_data, m_encrypt_model, + m_batch_data); } - -void runtime::he::HEBackend::clear_function_instance() { - m_function_map.clear(); -} - -void runtime::he::HEBackend::remove_compiled_function( - shared_ptr function) { - m_function_map.erase(function); -} - -void runtime::he::HEBackend::enable_performance_data( - shared_ptr function, bool enable) { - // Enabled by default -} - -vector -runtime::he::HEBackend::get_performance_data( - shared_ptr function) const { - vector rc; - const FunctionInstance& instance = m_function_map.at(function); - for (const pair p : instance.m_timer_map) { - rc.emplace_back(p.first->get_name().c_str(), - p.second.get_total_microseconds(), - p.second.get_call_count()); - } - return rc; -} - -bool runtime::he::HEBackend::call( - shared_ptr function, - const vector>& outputs, - const vector>& inputs) { - if (encrypt_data()) { - NGRAPH_INFO << "Encrypting data"; - } - if (batch_data()) { - NGRAPH_INFO << "Batching data"; - } - if (encrypt_model()) { - NGRAPH_INFO << "Encrypting model"; - } - - auto fit = m_function_map.find(function); - if (fit == m_function_map.end()) { - throw runtime_error("compile() must be called before call()."); - } - FunctionInstance& instance = fit->second; - - // convert outputs to HETensor - vector> he_inputs; - for (auto& tv : inputs) { - he_inputs.push_back(static_pointer_cast(tv)); - } - - // convert inputs to HETensor - vector> he_outputs; - for (auto& tv : outputs) { - he_outputs.push_back(static_pointer_cast(tv)); - } - - validate_he_call(function, he_outputs, he_inputs); - - // map function params -> HETensor - unordered_map> - tensor_map; - size_t input_count = 0; - for (auto param : function->get_parameters()) { - for (size_t i = 0; i < param->get_output_size(); ++i) { - descriptor::Tensor* tv = param->get_output_tensor_ptr(i).get(); - - if (encrypt_data()) { - NGRAPH_INFO << "Encrypting parameter " << i; - auto plain_input = static_pointer_cast( - he_inputs[input_count]); - assert(plain_input != nullptr); - auto cipher_input = static_pointer_cast( - create_cipher_tensor(plain_input->get_element_type(), - plain_input->get_shape(), batch_data())); - - NGRAPH_INFO << "plain_input->get_batched_element_count() " - << plain_input->get_batched_element_count(); -#pragma omp parallel for - for (size_t i = 0; i < plain_input->get_batched_element_count(); ++i) { - encrypt(cipher_input->get_element(i), *plain_input->get_element(i)); - } - - NGRAPH_INFO << "Done encrypting parameter " << i; - - tensor_map.insert({tv, cipher_input}); - input_count++; - } else { - tensor_map.insert({tv, he_inputs[input_count++]}); - } - } - } - - // map function outputs -> HostTensor - for (size_t output_count = 0; output_count < function->get_output_size(); - ++output_count) { - auto output = function->get_output_op(output_count); - if (!dynamic_pointer_cast(output)) { - throw ngraph_error("One of function's outputs isn't op::Result"); - } - descriptor::Tensor* tv = output->get_output_tensor_ptr(0).get(); - tensor_map.insert({tv, he_outputs[output_count++]}); - } - - // for each ordered op in the graph - for (const NodeWrapper& wrapped : instance.m_wrapped_nodes) { - const Node* op = &wrapped.get_node(); - auto type_id = wrapped.get_typeid(); - - NGRAPH_INFO << "\033[1;32m" - << "[ " << op->get_name() << " ]" - << "\033[0m"; - - if (type_id == OP_TYPEID::Parameter) { - NGRAPH_INFO << "Parameter shape {" << join(op->get_shape()) << "}"; - continue; - } - - if (op->description() == "Constant") { 
- NGRAPH_INFO << "Constant shape {" << join(op->get_shape()) << "}"; - } - - // get op inputs from map - vector> op_inputs; - for (const descriptor::Input& input : op->get_inputs()) { - descriptor::Tensor* tv = input.get_output().get_tensor_ptr().get(); - op_inputs.push_back(tensor_map.at(tv)); - } - - // get op outputs from map or create - vector> op_outputs; - for (size_t i = 0; i < op->get_output_size(); ++i) { - descriptor::Tensor* tv = op->get_output_tensor_ptr(i).get(); - auto it = tensor_map.find(tv); - if (it == tensor_map.end()) { - // The output tensor is not in the tensor map so create a new tensor - const Shape& shape = op->get_output_shape(i); - const element::Type& element_type = op->get_output_element_type(i); - string name = op->get_output_tensor(i).get_name(); - - bool plain_out = all_of( - op_inputs.begin(), op_inputs.end(), - [](shared_ptr op_input) { - return dynamic_pointer_cast(op_input) != nullptr; - }); - if (op->is_constant()) { - plain_out = !encrypt_model(); - } - - bool batched_out = any_of(op_inputs.begin(), op_inputs.end(), - [](shared_ptr he_tv) { - return he_tv->is_batched(); - }); - if (plain_out) { - auto otv = make_shared( - element_type, shape, this, create_empty_plaintext(), batched_out, - name); - tensor_map.insert({tv, otv}); - } else { - auto otv = make_shared( - element_type, shape, this, create_empty_ciphertext(), batched_out, - name); - tensor_map.insert({tv, otv}); - } - } - op_outputs.push_back(tensor_map.at(tv)); - } - - // get op type - element::Type base_type; - if (op->get_inputs().empty()) { - base_type = op->get_element_type(); - } else { - base_type = op->get_inputs().at(0).get_tensor().get_element_type(); - } - - instance.m_timer_map[op].start(); - generate_calls(base_type, wrapped, op_outputs, op_inputs, instance); - instance.m_timer_map[op].stop(); - - const string op_name = op->description(); - - // delete any obsolete tensors - for (const descriptor::Tensor* t : op->liveness_free_list) { - for (auto it = tensor_map.begin(); it != tensor_map.end(); ++it) { - if (it->second->get_name() == t->get_name()) { - tensor_map.erase(it); - break; - } - } - } - NGRAPH_INFO << "\033[1;31m" << op->get_name() << " took " - << instance.m_timer_map[op].get_milliseconds() << "ms" - << "\033[0m"; - } - size_t total_time = 0; - for (const auto& elem : instance.m_timer_map) { - total_time += elem.second.get_milliseconds(); - } - NGRAPH_INFO << "\033[1;32m" - << "Total time " << total_time << " (ms) \033[0m"; - return true; -} - -void runtime::he::HEBackend::generate_calls( - const element::Type& element_type, const NodeWrapper& node_wrapper, - const vector>& out, - const vector>& args, FunctionInstance& instance) { - const Node& node = node_wrapper.get_node(); - string node_op = node.description(); - shared_ptr arg0_cipher = nullptr; - shared_ptr arg0_plain = nullptr; - shared_ptr arg1_cipher = nullptr; - shared_ptr arg1_plain = nullptr; - auto out0_cipher = dynamic_pointer_cast(out[0]); - auto out0_plain = dynamic_pointer_cast(out[0]); - - if (args.size() > 0) { - arg0_cipher = dynamic_pointer_cast(args[0]); - arg0_plain = dynamic_pointer_cast(args[0]); - } - if (args.size() > 1) { - arg1_cipher = dynamic_pointer_cast(args[1]); - arg1_plain = dynamic_pointer_cast(args[1]); - } - - size_t batch_size = 1; - if (out0_cipher != nullptr) { - batch_size = out0_cipher->get_batch_size(); - } else if (out0_plain != nullptr) { - batch_size = out0_plain->get_batch_size(); - } - - stringstream ss; - ss << "Inputs: "; - if (arg0_cipher != nullptr) { - ss << 
"Cipher"; - } else if (arg0_plain != nullptr) { - ss << "Plain"; - } - if (arg1_cipher != nullptr) { - ss << ", Cipher"; - } else if (arg1_plain != nullptr) { - ss << ", Plain"; - } - NGRAPH_INFO << ss.str(); - ss.str(""); - ss << "Outputs: "; - if (out0_cipher != nullptr) { - ss << "Cipher"; - } else if (out0_plain != nullptr) { - ss << "Plain"; - } - NGRAPH_INFO << ss.str(); - - if (batch_size != 1) { - NGRAPH_INFO << "Batch size " << batch_size; - } - -// We want to check that every OP_TYPEID enumeration is included in the list. -// These GCC flags enable compile-time checking so that if an enumeration -// is not in the list an error is generated. -#pragma GCC diagnostic push -#pragma GCC diagnostic error "-Wswitch" -#pragma GCC diagnostic error "-Wswitch-enum" - switch (node_wrapper.get_typeid()) { - case OP_TYPEID::Add: { - if (arg0_cipher != nullptr && arg1_cipher != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::add( - arg0_cipher->get_elements(), arg1_cipher->get_elements(), - out0_cipher->get_elements(), element_type, this, - out0_cipher->get_batched_element_count()); - } else if (arg0_cipher != nullptr && arg1_plain != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::add( - arg0_cipher->get_elements(), arg1_plain->get_elements(), - out0_cipher->get_elements(), element_type, this, - out0_cipher->get_batched_element_count()); - } else if (arg0_plain != nullptr && arg1_cipher != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::add( - arg0_plain->get_elements(), arg1_cipher->get_elements(), - out0_cipher->get_elements(), element_type, this, - out0_cipher->get_batched_element_count()); - } else if (arg0_plain != nullptr && arg1_plain != nullptr && - out0_plain != nullptr) { - runtime::he::kernel::add(arg0_plain->get_elements(), - arg1_plain->get_elements(), - out0_plain->get_elements(), element_type, this, - out0_plain->get_batched_element_count()); - } else { - throw ngraph_error("Add types not supported."); - } - break; - } - case OP_TYPEID::AvgPool: { - const op::AvgPool* avg_pool = static_cast(&node); - if (arg0_cipher != nullptr && out0_cipher != nullptr) { - Shape in_shape = arg0_cipher->get_shape(); - Shape out_shape = out0_cipher->get_shape(); - runtime::he::kernel::avg_pool( - arg0_cipher->get_elements(), out0_cipher->get_elements(), in_shape, - out_shape, avg_pool->get_window_shape(), - avg_pool->get_window_movement_strides(), - avg_pool->get_padding_below(), avg_pool->get_padding_above(), - avg_pool->get_include_padding_in_avg_computation(), this); - - } else if (arg0_plain != nullptr && out0_plain != nullptr) { - Shape in_shape = arg0_plain->get_shape(); - Shape out_shape = out0_plain->get_shape(); - runtime::he::kernel::avg_pool( - arg0_plain->get_elements(), out0_plain->get_elements(), in_shape, - out_shape, avg_pool->get_window_shape(), - avg_pool->get_window_movement_strides(), - avg_pool->get_padding_below(), avg_pool->get_padding_above(), - avg_pool->get_include_padding_in_avg_computation(), this); - } else { - throw ngraph_error("Broadcast types not supported."); - } - break; - } - case OP_TYPEID::Broadcast: { - const op::Broadcast* broadcast = static_cast(&node); - AxisSet broadcast_axes = broadcast->get_broadcast_axes(); - - if (arg0_cipher != nullptr && out0_cipher != nullptr) { - Shape in_shape = arg0_cipher->get_shape(); - Shape out_shape = out0_cipher->get_shape(); - runtime::he::kernel::broadcast(arg0_cipher->get_elements(), - out0_cipher->get_elements(), in_shape, - out_shape, broadcast_axes); - } else if (arg0_plain 
!= nullptr && out0_plain != nullptr) { - Shape in_shape = arg0_plain->get_shape(); - Shape out_shape = out0_plain->get_shape(); - runtime::he::kernel::broadcast(arg0_plain->get_elements(), - out0_plain->get_elements(), in_shape, - out_shape, broadcast_axes); - } else { - throw ngraph_error("Broadcast types not supported."); - } - break; - } - case OP_TYPEID::BroadcastLike: - break; - case OP_TYPEID::Concat: { - const op::Concat* concat = static_cast(&node); - - if (arg0_cipher != nullptr && out0_cipher != nullptr) { - std::vector in_shapes; - std::vector>> in_args; - - for (shared_ptr arg : args) { - shared_ptr arg_cipher = - dynamic_pointer_cast(arg); - if (arg_cipher == nullptr) { - throw ngraph_error("Concat type not consistent"); - } - in_args.push_back(arg_cipher->get_elements()); - in_shapes.push_back(arg_cipher->get_shape()); - - runtime::he::kernel::concat(in_args, out0_cipher->get_elements(), - in_shapes, node.get_output_shape(0), - concat->get_concatenation_axis()); - } - } else if (arg0_plain != nullptr && out0_plain != nullptr) { - std::vector in_shapes; - std::vector>> in_args; - - for (shared_ptr arg : args) { - shared_ptr arg_plain = - dynamic_pointer_cast(arg); - if (arg_plain == nullptr) { - throw ngraph_error("Concat type not consistent"); - } - in_args.push_back(arg_plain->get_elements()); - in_shapes.push_back(arg_plain->get_shape()); - - runtime::he::kernel::concat(in_args, out0_plain->get_elements(), - in_shapes, node.get_output_shape(0), - concat->get_concatenation_axis()); - } - } else { - throw ngraph_error("Concat types not supported."); - } - break; - } - case OP_TYPEID::Constant: { - const op::Constant* constant = static_cast(&node); - - if (out0_plain != nullptr) { - runtime::he::kernel::constant(out0_plain->get_elements(), element_type, - constant->get_data_ptr(), this, - out0_plain->get_batched_element_count()); - } else if (out0_cipher != nullptr) { - runtime::he::kernel::constant(out0_cipher->get_elements(), element_type, - constant->get_data_ptr(), this, - out0_cipher->get_batched_element_count()); - } else { - throw ngraph_error("Constant type not supported."); - } - break; - } - case OP_TYPEID::Convolution: { - const op::Convolution* c = static_cast(&node); - auto window_movement_strides = c->get_window_movement_strides(); - auto window_dilation_strides = c->get_window_dilation_strides(); - auto padding_below = c->get_padding_below(); - auto padding_above = c->get_padding_above(); - auto data_dilation_strides = c->get_data_dilation_strides(); - - if (arg0_cipher != nullptr && arg1_cipher != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::convolution( - arg0_cipher->get_elements(), arg1_cipher->get_elements(), - out0_cipher->get_elements(), arg0_cipher->get_shape(), - arg1_cipher->get_shape(), out0_cipher->get_shape(), - window_movement_strides, window_dilation_strides, padding_below, - padding_above, data_dilation_strides, 0, 1, 1, 0, 0, 1, false, - element_type, batch_size, this); - } else if (arg0_cipher != nullptr && arg1_plain != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::convolution( - arg0_cipher->get_elements(), arg1_plain->get_elements(), - out0_cipher->get_elements(), arg0_cipher->get_shape(), - arg1_plain->get_shape(), out0_cipher->get_shape(), - window_movement_strides, window_dilation_strides, padding_below, - padding_above, data_dilation_strides, 0, 1, 1, 0, 0, 1, false, - element_type, batch_size, this); - } else if (arg0_plain != nullptr && arg1_cipher != nullptr && - out0_cipher != nullptr) { - 
runtime::he::kernel::convolution( - arg0_plain->get_elements(), arg1_cipher->get_elements(), - out0_cipher->get_elements(), arg0_plain->get_shape(), - arg1_cipher->get_shape(), out0_cipher->get_shape(), - window_movement_strides, window_dilation_strides, padding_below, - padding_above, data_dilation_strides, 0, 1, 1, 0, 0, 1, false, - element_type, batch_size, this); - } else if (arg0_plain != nullptr && arg1_plain != nullptr && - out0_plain != nullptr) { - runtime::he::kernel::convolution( - arg0_plain->get_elements(), arg1_plain->get_elements(), - out0_plain->get_elements(), arg0_plain->get_shape(), - arg1_plain->get_shape(), out0_plain->get_shape(), - window_movement_strides, window_dilation_strides, padding_below, - padding_above, data_dilation_strides, 0, 1, 1, 0, 0, 1, false, - element_type, batch_size, this); - } else { - throw ngraph_error("Convolution types not supported."); - } - break; - } - case OP_TYPEID::Dot: { - const op::Dot* dot = static_cast(&node); - - NGRAPH_INFO << join(args[0]->get_shape(), "x") << " dot " - << join(args[1]->get_shape(), "x"); - if (arg0_cipher != nullptr && arg1_cipher != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::dot( - arg0_cipher->get_elements(), arg1_cipher->get_elements(), - out0_cipher->get_elements(), arg0_cipher->get_shape(), - arg1_cipher->get_shape(), out0_cipher->get_shape(), - dot->get_reduction_axes_count(), element_type, this); - } else if (arg0_cipher != nullptr && arg1_plain != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::dot( - arg0_cipher->get_elements(), arg1_plain->get_elements(), - out0_cipher->get_elements(), arg0_cipher->get_shape(), - arg1_plain->get_shape(), out0_cipher->get_shape(), - dot->get_reduction_axes_count(), element_type, this); - } else if (arg0_plain != nullptr && arg1_cipher != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::dot( - arg0_plain->get_elements(), arg1_cipher->get_elements(), - out0_cipher->get_elements(), arg0_plain->get_shape(), - arg1_cipher->get_shape(), out0_cipher->get_shape(), - dot->get_reduction_axes_count(), element_type, this); - } else if (arg0_plain != nullptr && arg1_plain != nullptr && - out0_plain != nullptr) { - runtime::he::kernel::dot( - arg0_plain->get_elements(), arg1_plain->get_elements(), - out0_plain->get_elements(), arg0_plain->get_shape(), - arg1_plain->get_shape(), out0_plain->get_shape(), - dot->get_reduction_axes_count(), element_type, this); - } else { - throw ngraph_error("Dot types not supported."); - } - break; - } - case OP_TYPEID::Multiply: { - if (arg0_cipher != nullptr && arg1_cipher != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::multiply( - arg0_cipher->get_elements(), arg1_cipher->get_elements(), - out0_cipher->get_elements(), element_type, this, - out0_cipher->get_batched_element_count()); - } else if (arg0_cipher != nullptr && arg1_plain != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::multiply( - arg0_cipher->get_elements(), arg1_plain->get_elements(), - out0_cipher->get_elements(), element_type, this, - out0_cipher->get_batched_element_count()); - } else if (arg0_plain != nullptr && arg1_cipher != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::multiply( - arg0_plain->get_elements(), arg1_cipher->get_elements(), - out0_cipher->get_elements(), element_type, this, - out0_cipher->get_batched_element_count()); - } else if (arg0_plain != nullptr && arg1_plain != nullptr && - out0_plain != nullptr) { - runtime::he::kernel::multiply( - arg0_plain->get_elements(), 
arg1_plain->get_elements(), - out0_plain->get_elements(), element_type, this, - out0_plain->get_batched_element_count()); - } else { - throw ngraph_error("Multiply types not supported."); - } - break; - } - case OP_TYPEID::Negative: { - if (arg0_cipher != nullptr && out0_cipher != nullptr) { - runtime::he::kernel::negate( - arg0_cipher->get_elements(), out0_cipher->get_elements(), - element_type, this, out0_cipher->get_batched_element_count()); - } else if (arg0_plain != nullptr && out0_plain != nullptr) { - runtime::he::kernel::negate( - arg0_plain->get_elements(), out0_plain->get_elements(), - element_type, this, out0_plain->get_batched_element_count()); - } else { - throw ngraph_error("Negative types not supported."); - } - break; - } - case OP_TYPEID::Parameter: - NGRAPH_INFO << "Skipping parameter"; - break; - case OP_TYPEID::Pad: { - const op::Pad* pad = static_cast(&node); - - // TODO: clean up - Shape arg0_shape = node.get_inputs().at(0).get_shape(); - Shape out_shape = node.get_output_shape(0); - if (arg0_cipher != nullptr && out0_cipher != nullptr) { - NGRAPH_DEBUG << "arg0_cipher->is_batched(): " - << arg0_cipher->is_batched(); - NGRAPH_DEBUG << "arg0_cipher->get_batch_size(): " - << arg0_cipher->get_batch_size(); - if (arg0_cipher->is_batched()) { - arg0_shape[0] = arg0_shape[0] / arg0_cipher->get_batch_size(); - } - - NGRAPH_DEBUG << "out0_cipher->is_batched(): " - << out0_cipher->is_batched(); - NGRAPH_DEBUG << "arg0_cipher->get_batch_size(): " - << out0_cipher->get_batch_size(); - if (out0_cipher->is_batched()) { - out_shape[0] = out_shape[0] / out0_cipher->get_batch_size(); - } - } - - NGRAPH_DEBUG << "arg0_shape after batching: " << join(arg0_shape); - NGRAPH_DEBUG << "out_shape after batching: " << join(out_shape); - - if (arg0_cipher != nullptr && arg1_cipher != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::pad( - arg0_cipher->get_elements(), arg1_cipher->get_elements(), - out0_cipher->get_elements(), arg0_shape, out_shape, - pad->get_padding_below(), pad->get_padding_above(), - pad->get_padding_interior(), batch_size, this); - } else if (arg0_cipher != nullptr && arg1_plain != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::pad( - arg0_cipher->get_elements(), arg1_plain->get_elements(), - out0_cipher->get_elements(), arg0_shape, out_shape, - pad->get_padding_below(), pad->get_padding_above(), - pad->get_padding_interior(), batch_size, this); - } else { - throw ngraph_error("Pad cipher vs plain types not supported."); - } - break; - } - case OP_TYPEID::Reshape: { - NGRAPH_INFO << "Reshape op"; - const op::Reshape* reshape = static_cast(&node); - if (arg0_cipher != nullptr && out0_cipher != nullptr) { - runtime::he::kernel::reshape( - arg0_cipher->get_elements(), out0_cipher->get_elements(), - arg0_cipher->get_shape(), reshape->get_input_order(), - out0_cipher->get_shape()); - } else if (arg0_plain != nullptr && out0_plain != nullptr) { - runtime::he::kernel::reshape( - arg0_plain->get_elements(), out0_plain->get_elements(), - arg0_plain->get_shape(), reshape->get_input_order(), - out0_plain->get_shape()); - } else { - throw ngraph_error("Reshape types not supported."); - } - NGRAPH_INFO << "Done with reshape op"; - break; - } - case OP_TYPEID::Result: { - size_t output_size; - if (arg0_plain != nullptr) { - output_size = arg0_plain->get_batched_element_count(); - } else if (arg0_cipher != nullptr) { - output_size = arg0_cipher->get_batched_element_count(); - } else { - throw ngraph_error( - "Input argument is neither plaintext nor 
ciphertext"); - } - - if (arg0_cipher != nullptr && out0_cipher != nullptr) { - runtime::he::kernel::result(arg0_cipher->get_elements(), - out0_cipher->get_elements(), output_size); - } else if (arg0_plain != nullptr && out0_cipher != nullptr) { - runtime::he::kernel::result(arg0_plain->get_elements(), - out0_cipher->get_elements(), output_size, - this); - } else if (arg0_cipher != nullptr && out0_plain != nullptr) { - runtime::he::kernel::result(arg0_cipher->get_elements(), - out0_plain->get_elements(), output_size, - this); - } else if (arg0_plain != nullptr && out0_plain != nullptr) { - runtime::he::kernel::result(arg0_plain->get_elements(), - out0_plain->get_elements(), output_size); - } else { - throw ngraph_error("Result types not supported."); - } - break; - } - case OP_TYPEID::Reverse: { - const op::Reverse* reverse = static_cast(&node); - Shape in_shape = node.get_input_shape(0); - Shape out_shape = node.get_output_shape(0); - - if (arg0_cipher != nullptr && out0_cipher != nullptr) { - runtime::he::kernel::reverse(arg0_cipher->get_elements(), - out0_cipher->get_elements(), in_shape, - out_shape, reverse->get_reversed_axes()); - } else if (arg0_plain != nullptr && out0_plain != nullptr) { - runtime::he::kernel::reverse(arg0_plain->get_elements(), - out0_plain->get_elements(), in_shape, - out_shape, reverse->get_reversed_axes()); - } else { - throw ngraph_error("Reverse types not supported."); - } - break; - } - case OP_TYPEID::ScalarConstantLike: - break; - case OP_TYPEID::Slice: { - const op::Slice* slice = static_cast(&node); - Shape in_shape = node.get_input_shape(0); - Shape out_shape = node.get_output_shape(0); - - if (arg0_cipher != nullptr && out0_cipher != nullptr) { - runtime::he::kernel::slice( - arg0_cipher->get_elements(), out0_cipher->get_elements(), in_shape, - slice->get_lower_bounds(), slice->get_upper_bounds(), - slice->get_strides(), out_shape); - } else if (arg0_plain != nullptr && out0_plain != nullptr) { - runtime::he::kernel::slice( - arg0_plain->get_elements(), out0_plain->get_elements(), in_shape, - slice->get_lower_bounds(), slice->get_upper_bounds(), - slice->get_strides(), out_shape); - } else { - throw ngraph_error("Slice types not supported."); - } - break; - } - case OP_TYPEID::Subtract: { - if (arg0_cipher != nullptr && arg1_cipher != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::subtract( - arg0_cipher->get_elements(), arg1_cipher->get_elements(), - out0_cipher->get_elements(), element_type, this, - out0_cipher->get_batched_element_count()); - } else if (arg0_cipher != nullptr && arg1_plain != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::subtract( - arg0_cipher->get_elements(), arg1_plain->get_elements(), - out0_cipher->get_elements(), element_type, this, - out0_cipher->get_batched_element_count()); - } else if (arg0_plain != nullptr && arg1_cipher != nullptr && - out0_cipher != nullptr) { - runtime::he::kernel::subtract( - arg0_plain->get_elements(), arg1_cipher->get_elements(), - out0_cipher->get_elements(), element_type, this, - out0_cipher->get_batched_element_count()); - } else if (arg0_plain != nullptr && arg1_plain != nullptr && - out0_plain != nullptr) { - runtime::he::kernel::subtract( - arg0_plain->get_elements(), arg1_plain->get_elements(), - out0_plain->get_elements(), element_type, this, - out0_plain->get_batched_element_count()); - } else { - throw ngraph_error("Subtract types not supported."); - } - break; - } - case OP_TYPEID::Sum: { - const op::Sum* sum = static_cast(&node); - Shape in_shape = 
node.get_input_shape(0); - Shape out_shape = node.get_output_shape(0); - - if (arg0_cipher != nullptr && out0_cipher != nullptr) { - runtime::he::kernel::sum( - arg0_cipher->get_elements(), out0_cipher->get_elements(), in_shape, - out_shape, sum->get_reduction_axes(), element_type, this); - } else if (arg0_plain != nullptr && out0_plain != nullptr) { - runtime::he::kernel::sum( - arg0_plain->get_elements(), out0_plain->get_elements(), in_shape, - out_shape, sum->get_reduction_axes(), element_type, this); - } else { - throw ngraph_error("Sum types not supported."); - } - break; - } - // Unsupported ops - case OP_TYPEID::Abs: - case OP_TYPEID::Acos: - case OP_TYPEID::All: - case OP_TYPEID::AllReduce: - case OP_TYPEID::And: - case OP_TYPEID::Any: - case OP_TYPEID::ArgMax: - case OP_TYPEID::ArgMin: - case OP_TYPEID::Asin: - case OP_TYPEID::Atan: - case OP_TYPEID::AvgPoolBackprop: - case OP_TYPEID::BatchNormInference: - case OP_TYPEID::BatchNormTraining: - case OP_TYPEID::BatchNormTrainingBackprop: - case OP_TYPEID::Ceiling: - case OP_TYPEID::Convert: - case OP_TYPEID::ConvolutionBackpropData: - case OP_TYPEID::ConvolutionBackpropFilters: - case OP_TYPEID::Cos: - case OP_TYPEID::Cosh: - case OP_TYPEID::Dequantize: - case OP_TYPEID::Divide: - case OP_TYPEID::EmbeddingLookup: - case OP_TYPEID::Equal: - case OP_TYPEID::Exp: - case OP_TYPEID::Floor: - case OP_TYPEID::GenerateMask: - case OP_TYPEID::GetOutputElement: - case OP_TYPEID::Greater: - case OP_TYPEID::GreaterEq: - case OP_TYPEID::Less: - case OP_TYPEID::LessEq: - case OP_TYPEID::Log: - case OP_TYPEID::LRN: - case OP_TYPEID::Max: - case OP_TYPEID::Maximum: - case OP_TYPEID::MaxPool: - case OP_TYPEID::MaxPoolBackprop: - case OP_TYPEID::Min: - case OP_TYPEID::Minimum: - case OP_TYPEID::Not: - case OP_TYPEID::NotEqual: - case OP_TYPEID::OneHot: - case OP_TYPEID::Or: - case OP_TYPEID::Power: - case OP_TYPEID::Product: - case OP_TYPEID::Quantize: - case OP_TYPEID::QuantizedAvgPool: - case OP_TYPEID::QuantizedConvolutionBias: - case OP_TYPEID::QuantizedConvolutionBiasAdd: - case OP_TYPEID::QuantizedConvolutionBiasSignedAdd: - case OP_TYPEID::QuantizedConvolutionRelu: - case OP_TYPEID::QuantizedConvolution: - case OP_TYPEID::QuantizedMaxPool: - case OP_TYPEID::Relu: - case OP_TYPEID::ReluBackprop: - case OP_TYPEID::ReplaceSlice: - case OP_TYPEID::ReverseSequence: - case OP_TYPEID::Select: - case OP_TYPEID::ShapeOf: - case OP_TYPEID::Sigmoid: - case OP_TYPEID::SigmoidBackprop: - case OP_TYPEID::Sign: - case OP_TYPEID::Sin: - case OP_TYPEID::Sinh: - case OP_TYPEID::Softmax: - case OP_TYPEID::Sqrt: - case OP_TYPEID::StopGradient: - case OP_TYPEID::Tan: - case OP_TYPEID::Tanh: - case OP_TYPEID::TopK: - default: - throw unsupported_op("Unsupported op '" + node.description() + "'"); -#pragma GCC diagnostic pop - } -} \ No newline at end of file diff --git a/src/he_backend.hpp b/src/he_backend.hpp index 9fb342d6..ec2701df 100644 --- a/src/he_backend.hpp +++ b/src/he_backend.hpp @@ -120,7 +120,6 @@ class HEBackend : public runtime::Backend { virtual std::shared_ptr create_batched_plain_tensor( const element::Type& element_type, const Shape& shape) = 0; - /// @brief Return a handle for a tensor for given mem on backend device std::shared_ptr create_tensor( const element::Type& element_type, const Shape& shape, void* memory_pointer) override; @@ -147,22 +146,15 @@ class HEBackend : public runtime::Backend { std::shared_ptr create_valued_plain_tensor( float value, const element::Type& element_type, const Shape& shape) const; - runtime::Handle 
compile(std::shared_ptr function) override; - - bool call( - std::shared_ptr function, - const std::vector>& outputs, - const std::vector>& inputs) override; + std::shared_ptr compile( + std::shared_ptr func, + bool enable_performance_data = false) override; void validate_he_call( std::shared_ptr function, const std::vector>& outputs, const std::vector>& inputs); - void clear_function_instance(); - - void remove_compiled_function(std::shared_ptr function) override; - /// @brief Encodes bytes to a plaintext polynomial /// @param output Pointer to plaintext to write to /// @param input Pointer to memory to encode @@ -193,12 +185,6 @@ class HEBackend : public runtime::Backend { virtual void decrypt(std::shared_ptr& output, const runtime::he::HECiphertext& input) const = 0; - void enable_performance_data(std::shared_ptr function, - bool enable) override; - - std::vector get_performance_data( - std::shared_ptr function) const override; - /// @brief Return whether or not scalar optimizations are enabled bool optimized_add() const { return m_optimized_add; }; bool optimized_mult() const { return m_optimized_mult; }; @@ -212,27 +198,11 @@ class HEBackend : public runtime::Backend { bool encrypt_model() const { return m_encrypt_model; }; private: - class FunctionInstance { - public: - bool m_is_compiled = false; - bool m_nan_check_enabled = false; - bool m_performance_counters_enabled = false; - std::unordered_map m_timer_map; - std::vector m_wrapped_nodes; - }; - std::map, FunctionInstance> m_function_map; - bool m_optimized_add{std::getenv("NGRAPH_OPTIMIZED_ADD") != nullptr}; bool m_optimized_mult{std::getenv("NGRAPH_OPTIMIZED_MULT") != nullptr}; bool m_encrypt_data{std::getenv("NGRAPH_ENCRYPT_DATA") != nullptr}; bool m_batch_data{std::getenv("NGRAPH_BATCH_DATA") != nullptr}; bool m_encrypt_model{std::getenv("NGRAPH_ENCRYPT_MODEL") != nullptr}; - - void generate_calls( - const element::Type& element_type, const NodeWrapper& op, - const std::vector>& outputs, - const std::vector>& inputs, - FunctionInstance& instance); }; } // namespace he } // namespace runtime diff --git a/src/he_executable.cpp b/src/he_executable.cpp new file mode 100644 index 00000000..ba3587c6 --- /dev/null +++ b/src/he_executable.cpp @@ -0,0 +1,932 @@ +//***************************************************************************** +// Copyright 2018-2019 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//***************************************************************************** + +#include +#include + +#include "he_cipher_tensor.hpp" +#include "he_executable.hpp" +#include "he_plain_tensor.hpp" +#include "he_tensor.hpp" +#include "kernel/add.hpp" +#include "kernel/avg_pool.hpp" +#include "kernel/broadcast.hpp" +#include "kernel/concat.hpp" +#include "kernel/constant.hpp" +#include "kernel/convolution.hpp" +#include "kernel/dot.hpp" +#include "kernel/multiply.hpp" +#include "kernel/negate.hpp" +#include "kernel/pad.hpp" +#include "kernel/reshape.hpp" +#include "kernel/result.hpp" +#include "kernel/reverse.hpp" +#include "kernel/slice.hpp" +#include "kernel/subtract.hpp" +#include "kernel/sum.hpp" +#include "ngraph/assertion.hpp" +#include "ngraph/descriptor/layout/dense_tensor_layout.hpp" +#include "ngraph/op/avg_pool.hpp" +#include "ngraph/op/broadcast.hpp" +#include "ngraph/op/concat.hpp" +#include "ngraph/op/constant.hpp" +#include "ngraph/op/convolution.hpp" +#include "ngraph/op/dot.hpp" +#include "ngraph/op/pad.hpp" +#include "ngraph/op/passthrough.hpp" +#include "ngraph/op/reshape.hpp" +#include "ngraph/op/result.hpp" +#include "ngraph/op/reverse.hpp" +#include "ngraph/op/slice.hpp" +#include "ngraph/op/sum.hpp" +#include "ngraph/pass/assign_layout.hpp" +#include "ngraph/pass/like_replacement.hpp" +#include "ngraph/pass/liveness.hpp" +#include "ngraph/pass/manager.hpp" +#include "ngraph/pass/memory_layout.hpp" +#include "ngraph/pass/visualize_tree.hpp" +#include "ngraph/runtime/backend.hpp" +#include "ngraph/util.hpp" + +using namespace std; +using namespace ngraph; +using descriptor::layout::DenseTensorLayout; + +runtime::he::HEExecutable::HEExecutable(const shared_ptr& function, + bool enable_performance_collection, + const HEBackend* he_backend, + bool encrypt_data, bool encrypt_model, + bool batch_data) + : m_he_backend(he_backend), + m_encrypt_data(encrypt_data), + m_encrypt_model(encrypt_model), + m_batch_data(batch_data) { + NGRAPH_INFO << "Compiling function"; + NGRAPH_ASSERT(he_backend != nullptr) << "he_backend == nullptr"; + + m_is_compiled = true; + pass::Manager pass_manager; + pass_manager.register_pass(); + pass_manager.register_pass>(); + pass_manager.register_pass(); + pass_manager.run_passes(function); + + for (const shared_ptr& node : function->get_ordered_ops()) { + m_wrapped_nodes.emplace_back(node); + } + set_parameters_and_results(*function); +} + +vector +runtime::he::HEExecutable::get_performance_data() const { + vector rc; + for (const pair p : m_timer_map) { + rc.emplace_back(p.first->get_name().c_str(), + p.second.get_total_microseconds(), + p.second.get_call_count()); + } + return rc; +} + +void runtime::he::HEExecutable::he_validate( + const vector>& outputs, + const vector>& inputs) { + const ParameterVector& parameters = get_parameters(); + const ResultVector& results = get_results(); + if (parameters.size() != inputs.size()) { + stringstream ss; + ss << "Call input count " << inputs.size() + << " does not match Function's Parameter count " << parameters.size(); + throw runtime_error(ss.str()); + } + if (results.size() != outputs.size()) { + stringstream ss; + ss << "Call output count " << outputs.size() + << " does not match Function's Result count " << results.size(); + throw runtime_error(ss.str()); + } + + for (size_t i = 0; i < parameters.size(); i++) { + if (parameters[i]->get_element_type() != inputs[i]->get_element_type()) { + stringstream ss; + ss << "Input " << i << " type '" << inputs[i]->get_element_type() + << "' does not match 
Parameter type '" + << parameters[i]->get_element_type() << "'"; + throw runtime_error(ss.str()); + } + if (inputs[i]->get_expanded_shape() != parameters[i]->get_shape()) { + stringstream ss; + ss << "Input " << i << " shape {" << join(inputs[i]->get_expanded_shape()) + << "} does not match Parameter shape {" + << join(parameters[i]->get_shape()) << "}"; + throw runtime_error(ss.str()); + } + } + + for (size_t i = 0; i < results.size(); i++) { + if (results[i]->get_element_type() != outputs[i]->get_element_type()) { + stringstream ss; + ss << "Output " << i << " type '" << outputs[i]->get_element_type() + << "' does not match Result type '" << results[i]->get_element_type() + << "'"; + throw runtime_error(ss.str()); + } + if (results[i]->get_shape() != outputs[i]->get_expanded_shape()) { + stringstream ss; + ss << "Output " << i << " shape {" + << join(outputs[i]->get_expanded_shape()) + << "} does not match Result shape {" << join(results[i]->get_shape()) + << "}"; + throw runtime_error(ss.str()); + } + } +} + +bool runtime::he::HEExecutable::call( + const vector>& outputs, + const vector>& inputs) { + NGRAPH_INFO << "HEExecutable::call"; + + if (m_encrypt_data) { + NGRAPH_INFO << "Encrypting data"; + } + if (m_batch_data) { + NGRAPH_INFO << "Batching data"; + } + if (m_encrypt_model) { + NGRAPH_INFO << "Encrypting model"; + } + + // convert outputs to HETensor + vector> he_inputs; + for (auto& tv : inputs) { + he_inputs.push_back(static_pointer_cast(tv)); + } + + // convert inputs to HETensor + vector> he_outputs; + for (auto& tv : outputs) { + he_outputs.push_back(static_pointer_cast(tv)); + } + + he_validate(he_outputs, he_inputs); + + // map function params -> HETensor + unordered_map> + tensor_map; + size_t input_count = 0; + for (auto param : get_parameters()) { + for (size_t i = 0; i < param->get_output_size(); ++i) { + descriptor::Tensor* tv = param->get_output_tensor_ptr(i).get(); + + if (m_encrypt_data) { + NGRAPH_INFO << "Encrypting parameter " << i; + auto plain_input = static_pointer_cast( + he_inputs[input_count]); + assert(plain_input != nullptr); + auto cipher_input = static_pointer_cast( + m_he_backend->create_cipher_tensor(plain_input->get_element_type(), + plain_input->get_shape(), + m_batch_data)); + + NGRAPH_INFO << "plain_input->get_batched_element_count() " + << plain_input->get_batched_element_count(); +#pragma omp parallel for + for (size_t i = 0; i < plain_input->get_batched_element_count(); ++i) { + m_he_backend->encrypt(cipher_input->get_element(i), + *plain_input->get_element(i)); + } + + NGRAPH_INFO << "Done encrypting parameter " << i; + + tensor_map.insert({tv, cipher_input}); + input_count++; + } else { + tensor_map.insert({tv, he_inputs[input_count++]}); + } + } + } + + // map function outputs -> HostTensor + for (size_t output_count = 0; output_count < get_results().size(); + ++output_count) { + auto output = get_results()[output_count]; + if (!dynamic_pointer_cast(output)) { + throw ngraph_error("One of function's outputs isn't op::Result"); + } + descriptor::Tensor* tv = output->get_output_tensor_ptr(0).get(); + tensor_map.insert({tv, he_outputs[output_count++]}); + } + + // for each ordered op in the graph + for (const NodeWrapper& wrapped : m_wrapped_nodes) { + const Node* op = &wrapped.get_node(); + auto type_id = wrapped.get_typeid(); + + NGRAPH_INFO << "\033[1;32m" + << "[ " << op->get_name() << " ]" + << "\033[0m"; + + if (type_id == OP_TYPEID::Parameter) { + NGRAPH_INFO << "Parameter shape {" << join(op->get_shape()) << "}"; + continue; + } + + 
if (op->description() == "Constant") { + NGRAPH_INFO << "Constant shape {" << join(op->get_shape()) << "}"; + } + + // get op inputs from map + vector> op_inputs; + for (const descriptor::Input& input : op->get_inputs()) { + descriptor::Tensor* tv = input.get_output().get_tensor_ptr().get(); + op_inputs.push_back(tensor_map.at(tv)); + } + + // get op outputs from map or create + vector> op_outputs; + for (size_t i = 0; i < op->get_output_size(); ++i) { + descriptor::Tensor* tv = op->get_output_tensor_ptr(i).get(); + auto it = tensor_map.find(tv); + if (it == tensor_map.end()) { + // The output tensor is not in the tensor map so create a new tensor + const Shape& shape = op->get_output_shape(i); + const element::Type& element_type = op->get_output_element_type(i); + string name = op->get_output_tensor(i).get_name(); + + bool plain_out = all_of( + op_inputs.begin(), op_inputs.end(), + [](shared_ptr op_input) { + return dynamic_pointer_cast(op_input) != nullptr; + }); + if (op->is_constant()) { + plain_out = !m_encrypt_model; + } + + bool batched_out = any_of(op_inputs.begin(), op_inputs.end(), + [](shared_ptr he_tv) { + return he_tv->is_batched(); + }); + if (plain_out) { + auto otv = make_shared( + element_type, shape, m_he_backend, + m_he_backend->create_empty_plaintext(), batched_out, name); + tensor_map.insert({tv, otv}); + } else { + auto otv = make_shared( + element_type, shape, m_he_backend, + m_he_backend->create_empty_ciphertext(), batched_out, name); + tensor_map.insert({tv, otv}); + } + } + op_outputs.push_back(tensor_map.at(tv)); + } + + // get op type + element::Type base_type; + if (op->get_inputs().empty()) { + base_type = op->get_element_type(); + } else { + base_type = op->get_inputs().at(0).get_tensor().get_element_type(); + } + + m_timer_map[op].start(); + generate_calls(base_type, wrapped, op_outputs, op_inputs); + m_timer_map[op].stop(); + + const string op_name = op->description(); + + // delete any obsolete tensors + for (const descriptor::Tensor* t : op->liveness_free_list) { + for (auto it = tensor_map.begin(); it != tensor_map.end(); ++it) { + if (it->second->get_name() == t->get_name()) { + tensor_map.erase(it); + break; + } + } + } + NGRAPH_INFO << "\033[1;31m" << op->get_name() << " took " + << m_timer_map[op].get_milliseconds() << "ms" + << "\033[0m"; + } + size_t total_time = 0; + for (const auto& elem : m_timer_map) { + total_time += elem.second.get_milliseconds(); + } + NGRAPH_INFO << "\033[1;32m" + << "Total time " << total_time << " (ms) \033[0m"; + return true; +} + +void runtime::he::HEExecutable::generate_calls( + const element::Type& type, const NodeWrapper& node_wrapper, + const vector>& out, + const vector>& args) { + const Node& node = node_wrapper.get_node(); + string node_op = node.description(); + shared_ptr arg0_cipher = nullptr; + shared_ptr arg0_plain = nullptr; + shared_ptr arg1_cipher = nullptr; + shared_ptr arg1_plain = nullptr; + auto out0_cipher = dynamic_pointer_cast(out[0]); + auto out0_plain = dynamic_pointer_cast(out[0]); + + if (args.size() > 0) { + arg0_cipher = dynamic_pointer_cast(args[0]); + arg0_plain = dynamic_pointer_cast(args[0]); + } + if (args.size() > 1) { + arg1_cipher = dynamic_pointer_cast(args[1]); + arg1_plain = dynamic_pointer_cast(args[1]); + } + + size_t batch_size = 1; + if (out0_cipher != nullptr) { + batch_size = out0_cipher->get_batch_size(); + } else if (out0_plain != nullptr) { + batch_size = out0_plain->get_batch_size(); + } + + stringstream ss; + ss << "Inputs: "; + if (arg0_cipher != nullptr) { + ss << 
"Cipher"; + } else if (arg0_plain != nullptr) { + ss << "Plain"; + } + if (arg1_cipher != nullptr) { + ss << ", Cipher"; + } else if (arg1_plain != nullptr) { + ss << ", Plain"; + } + NGRAPH_INFO << ss.str(); + ss.str(""); + ss << "Outputs: "; + if (out0_cipher != nullptr) { + ss << "Cipher"; + } else if (out0_plain != nullptr) { + ss << "Plain"; + } + NGRAPH_INFO << ss.str(); + + if (batch_size != 1) { + NGRAPH_INFO << "Batch size " << batch_size; + } + +// We want to check that every OP_TYPEID enumeration is included in the list. +// These GCC flags enable compile-time checking so that if an enumeration +// is not in the list an error is generated. +#pragma GCC diagnostic push +#pragma GCC diagnostic error "-Wswitch" +#pragma GCC diagnostic error "-Wswitch-enum" + switch (node_wrapper.get_typeid()) { + case OP_TYPEID::Add: { + if (arg0_cipher != nullptr && arg1_cipher != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::add( + arg0_cipher->get_elements(), arg1_cipher->get_elements(), + out0_cipher->get_elements(), type, m_he_backend, + out0_cipher->get_batched_element_count()); + } else if (arg0_cipher != nullptr && arg1_plain != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::add( + arg0_cipher->get_elements(), arg1_plain->get_elements(), + out0_cipher->get_elements(), type, m_he_backend, + out0_cipher->get_batched_element_count()); + } else if (arg0_plain != nullptr && arg1_cipher != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::add( + arg0_plain->get_elements(), arg1_cipher->get_elements(), + out0_cipher->get_elements(), type, m_he_backend, + out0_cipher->get_batched_element_count()); + } else if (arg0_plain != nullptr && arg1_plain != nullptr && + out0_plain != nullptr) { + runtime::he::kernel::add(arg0_plain->get_elements(), + arg1_plain->get_elements(), + out0_plain->get_elements(), type, m_he_backend, + out0_plain->get_batched_element_count()); + } else { + throw ngraph_error("Add types not supported."); + } + break; + } + case OP_TYPEID::AvgPool: { + const op::AvgPool* avg_pool = static_cast(&node); + if (arg0_cipher != nullptr && out0_cipher != nullptr) { + Shape in_shape = arg0_cipher->get_shape(); + Shape out_shape = out0_cipher->get_shape(); + runtime::he::kernel::avg_pool( + arg0_cipher->get_elements(), out0_cipher->get_elements(), in_shape, + out_shape, avg_pool->get_window_shape(), + avg_pool->get_window_movement_strides(), + avg_pool->get_padding_below(), avg_pool->get_padding_above(), + avg_pool->get_include_padding_in_avg_computation(), m_he_backend); + + } else if (arg0_plain != nullptr && out0_plain != nullptr) { + Shape in_shape = arg0_plain->get_shape(); + Shape out_shape = out0_plain->get_shape(); + runtime::he::kernel::avg_pool( + arg0_plain->get_elements(), out0_plain->get_elements(), in_shape, + out_shape, avg_pool->get_window_shape(), + avg_pool->get_window_movement_strides(), + avg_pool->get_padding_below(), avg_pool->get_padding_above(), + avg_pool->get_include_padding_in_avg_computation(), m_he_backend); + } else { + throw ngraph_error("Broadcast types not supported."); + } + break; + } + case OP_TYPEID::Broadcast: { + const op::Broadcast* broadcast = static_cast(&node); + AxisSet broadcast_axes = broadcast->get_broadcast_axes(); + + if (arg0_cipher != nullptr && out0_cipher != nullptr) { + Shape in_shape = arg0_cipher->get_shape(); + Shape out_shape = out0_cipher->get_shape(); + runtime::he::kernel::broadcast(arg0_cipher->get_elements(), + out0_cipher->get_elements(), in_shape, + out_shape, broadcast_axes); + } else 
if (arg0_plain != nullptr && out0_plain != nullptr) { + Shape in_shape = arg0_plain->get_shape(); + Shape out_shape = out0_plain->get_shape(); + runtime::he::kernel::broadcast(arg0_plain->get_elements(), + out0_plain->get_elements(), in_shape, + out_shape, broadcast_axes); + } else { + throw ngraph_error("Broadcast types not supported."); + } + break; + } + case OP_TYPEID::BroadcastLike: + break; + case OP_TYPEID::Concat: { + const op::Concat* concat = static_cast(&node); + + if (arg0_cipher != nullptr && out0_cipher != nullptr) { + std::vector in_shapes; + std::vector>> in_args; + + for (shared_ptr arg : args) { + shared_ptr arg_cipher = + dynamic_pointer_cast(arg); + if (arg_cipher == nullptr) { + throw ngraph_error("Concat type not consistent"); + } + in_args.push_back(arg_cipher->get_elements()); + in_shapes.push_back(arg_cipher->get_shape()); + + runtime::he::kernel::concat(in_args, out0_cipher->get_elements(), + in_shapes, node.get_output_shape(0), + concat->get_concatenation_axis()); + } + } else if (arg0_plain != nullptr && out0_plain != nullptr) { + std::vector in_shapes; + std::vector>> in_args; + + for (shared_ptr arg : args) { + shared_ptr arg_plain = + dynamic_pointer_cast(arg); + if (arg_plain == nullptr) { + throw ngraph_error("Concat type not consistent"); + } + in_args.push_back(arg_plain->get_elements()); + in_shapes.push_back(arg_plain->get_shape()); + + runtime::he::kernel::concat(in_args, out0_plain->get_elements(), + in_shapes, node.get_output_shape(0), + concat->get_concatenation_axis()); + } + } else { + throw ngraph_error("Concat types not supported."); + } + break; + } + case OP_TYPEID::Constant: { + const op::Constant* constant = static_cast(&node); + + if (out0_plain != nullptr) { + runtime::he::kernel::constant(out0_plain->get_elements(), type, + constant->get_data_ptr(), m_he_backend, + out0_plain->get_batched_element_count()); + } else if (out0_cipher != nullptr) { + runtime::he::kernel::constant(out0_cipher->get_elements(), type, + constant->get_data_ptr(), m_he_backend, + out0_cipher->get_batched_element_count()); + } else { + throw ngraph_error("Constant type not supported."); + } + break; + } + case OP_TYPEID::Convolution: { + const op::Convolution* c = static_cast(&node); + auto window_movement_strides = c->get_window_movement_strides(); + auto window_dilation_strides = c->get_window_dilation_strides(); + auto padding_below = c->get_padding_below(); + auto padding_above = c->get_padding_above(); + auto data_dilation_strides = c->get_data_dilation_strides(); + + if (arg0_cipher != nullptr && arg1_cipher != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::convolution( + arg0_cipher->get_elements(), arg1_cipher->get_elements(), + out0_cipher->get_elements(), arg0_cipher->get_shape(), + arg1_cipher->get_shape(), out0_cipher->get_shape(), + window_movement_strides, window_dilation_strides, padding_below, + padding_above, data_dilation_strides, 0, 1, 1, 0, 0, 1, false, type, + batch_size, m_he_backend); + } else if (arg0_cipher != nullptr && arg1_plain != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::convolution( + arg0_cipher->get_elements(), arg1_plain->get_elements(), + out0_cipher->get_elements(), arg0_cipher->get_shape(), + arg1_plain->get_shape(), out0_cipher->get_shape(), + window_movement_strides, window_dilation_strides, padding_below, + padding_above, data_dilation_strides, 0, 1, 1, 0, 0, 1, false, type, + batch_size, m_he_backend); + } else if (arg0_plain != nullptr && arg1_cipher != nullptr && + out0_cipher != nullptr) 
{ + runtime::he::kernel::convolution( + arg0_plain->get_elements(), arg1_cipher->get_elements(), + out0_cipher->get_elements(), arg0_plain->get_shape(), + arg1_cipher->get_shape(), out0_cipher->get_shape(), + window_movement_strides, window_dilation_strides, padding_below, + padding_above, data_dilation_strides, 0, 1, 1, 0, 0, 1, false, type, + batch_size, m_he_backend); + } else if (arg0_plain != nullptr && arg1_plain != nullptr && + out0_plain != nullptr) { + runtime::he::kernel::convolution( + arg0_plain->get_elements(), arg1_plain->get_elements(), + out0_plain->get_elements(), arg0_plain->get_shape(), + arg1_plain->get_shape(), out0_plain->get_shape(), + window_movement_strides, window_dilation_strides, padding_below, + padding_above, data_dilation_strides, 0, 1, 1, 0, 0, 1, false, type, + batch_size, m_he_backend); + } else { + throw ngraph_error("Convolution types not supported."); + } + break; + } + case OP_TYPEID::Dot: { + const op::Dot* dot = static_cast(&node); + + NGRAPH_INFO << join(args[0]->get_shape(), "x") << " dot " + << join(args[1]->get_shape(), "x"); + if (arg0_cipher != nullptr && arg1_cipher != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::dot( + arg0_cipher->get_elements(), arg1_cipher->get_elements(), + out0_cipher->get_elements(), arg0_cipher->get_shape(), + arg1_cipher->get_shape(), out0_cipher->get_shape(), + dot->get_reduction_axes_count(), type, m_he_backend); + } else if (arg0_cipher != nullptr && arg1_plain != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::dot( + arg0_cipher->get_elements(), arg1_plain->get_elements(), + out0_cipher->get_elements(), arg0_cipher->get_shape(), + arg1_plain->get_shape(), out0_cipher->get_shape(), + dot->get_reduction_axes_count(), type, m_he_backend); + } else if (arg0_plain != nullptr && arg1_cipher != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::dot( + arg0_plain->get_elements(), arg1_cipher->get_elements(), + out0_cipher->get_elements(), arg0_plain->get_shape(), + arg1_cipher->get_shape(), out0_cipher->get_shape(), + dot->get_reduction_axes_count(), type, m_he_backend); + } else if (arg0_plain != nullptr && arg1_plain != nullptr && + out0_plain != nullptr) { + runtime::he::kernel::dot( + arg0_plain->get_elements(), arg1_plain->get_elements(), + out0_plain->get_elements(), arg0_plain->get_shape(), + arg1_plain->get_shape(), out0_plain->get_shape(), + dot->get_reduction_axes_count(), type, m_he_backend); + } else { + throw ngraph_error("Dot types not supported."); + } + break; + } + case OP_TYPEID::Multiply: { + if (arg0_cipher != nullptr && arg1_cipher != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::multiply( + arg0_cipher->get_elements(), arg1_cipher->get_elements(), + out0_cipher->get_elements(), type, m_he_backend, + out0_cipher->get_batched_element_count()); + } else if (arg0_cipher != nullptr && arg1_plain != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::multiply( + arg0_cipher->get_elements(), arg1_plain->get_elements(), + out0_cipher->get_elements(), type, m_he_backend, + out0_cipher->get_batched_element_count()); + } else if (arg0_plain != nullptr && arg1_cipher != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::multiply( + arg0_plain->get_elements(), arg1_cipher->get_elements(), + out0_cipher->get_elements(), type, m_he_backend, + out0_cipher->get_batched_element_count()); + } else if (arg0_plain != nullptr && arg1_plain != nullptr && + out0_plain != nullptr) { + runtime::he::kernel::multiply( + arg0_plain->get_elements(), 
arg1_plain->get_elements(), + out0_plain->get_elements(), type, m_he_backend, + out0_plain->get_batched_element_count()); + } else { + throw ngraph_error("Multiply types not supported."); + } + break; + } + case OP_TYPEID::Negative: { + if (arg0_cipher != nullptr && out0_cipher != nullptr) { + runtime::he::kernel::negate( + arg0_cipher->get_elements(), out0_cipher->get_elements(), type, + m_he_backend, out0_cipher->get_batched_element_count()); + } else if (arg0_plain != nullptr && out0_plain != nullptr) { + runtime::he::kernel::negate( + arg0_plain->get_elements(), out0_plain->get_elements(), type, + m_he_backend, out0_plain->get_batched_element_count()); + } else { + throw ngraph_error("Negative types not supported."); + } + break; + } + case OP_TYPEID::Parameter: + NGRAPH_INFO << "Skipping parameter"; + break; + case OP_TYPEID::Pad: { + const op::Pad* pad = static_cast(&node); + + // TODO: clean up + Shape arg0_shape = node.get_inputs().at(0).get_shape(); + Shape out_shape = node.get_output_shape(0); + if (arg0_cipher != nullptr && out0_cipher != nullptr) { + NGRAPH_DEBUG << "arg0_cipher->is_batched(): " + << arg0_cipher->is_batched(); + NGRAPH_DEBUG << "arg0_cipher->get_batch_size(): " + << arg0_cipher->get_batch_size(); + if (arg0_cipher->is_batched()) { + arg0_shape[0] = arg0_shape[0] / arg0_cipher->get_batch_size(); + } + + NGRAPH_DEBUG << "out0_cipher->is_batched(): " + << out0_cipher->is_batched(); + NGRAPH_DEBUG << "arg0_cipher->get_batch_size(): " + << out0_cipher->get_batch_size(); + if (out0_cipher->is_batched()) { + out_shape[0] = out_shape[0] / out0_cipher->get_batch_size(); + } + } + + NGRAPH_DEBUG << "arg0_shape after batching: " << join(arg0_shape); + NGRAPH_DEBUG << "out_shape after batching: " << join(out_shape); + + if (arg0_cipher != nullptr && arg1_cipher != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::pad( + arg0_cipher->get_elements(), arg1_cipher->get_elements(), + out0_cipher->get_elements(), arg0_shape, out_shape, + pad->get_padding_below(), pad->get_padding_above(), + pad->get_padding_interior(), batch_size, m_he_backend); + } else if (arg0_cipher != nullptr && arg1_plain != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::pad( + arg0_cipher->get_elements(), arg1_plain->get_elements(), + out0_cipher->get_elements(), arg0_shape, out_shape, + pad->get_padding_below(), pad->get_padding_above(), + pad->get_padding_interior(), batch_size, m_he_backend); + } else { + throw ngraph_error("Pad cipher vs plain types not supported."); + } + break; + } + case OP_TYPEID::Passthrough: { + const op::Passthrough* passthrough = + static_cast(&node); + throw unsupported_op{"Unsupported operation language: " + + passthrough->language()}; + } + case OP_TYPEID::Reshape: { + NGRAPH_INFO << "Reshape op"; + const op::Reshape* reshape = static_cast(&node); + if (arg0_cipher != nullptr && out0_cipher != nullptr) { + runtime::he::kernel::reshape( + arg0_cipher->get_elements(), out0_cipher->get_elements(), + arg0_cipher->get_shape(), reshape->get_input_order(), + out0_cipher->get_shape()); + } else if (arg0_plain != nullptr && out0_plain != nullptr) { + runtime::he::kernel::reshape( + arg0_plain->get_elements(), out0_plain->get_elements(), + arg0_plain->get_shape(), reshape->get_input_order(), + out0_plain->get_shape()); + } else { + throw ngraph_error("Reshape types not supported."); + } + NGRAPH_INFO << "Done with reshape op"; + break; + } + case OP_TYPEID::Result: { + size_t output_size; + if (arg0_plain != nullptr) { + output_size = 
arg0_plain->get_batched_element_count(); + } else if (arg0_cipher != nullptr) { + output_size = arg0_cipher->get_batched_element_count(); + } else { + throw ngraph_error( + "Input argument is neither plaintext nor ciphertext"); + } + + if (arg0_cipher != nullptr && out0_cipher != nullptr) { + runtime::he::kernel::result(arg0_cipher->get_elements(), + out0_cipher->get_elements(), output_size); + } else if (arg0_plain != nullptr && out0_cipher != nullptr) { + runtime::he::kernel::result(arg0_plain->get_elements(), + out0_cipher->get_elements(), output_size, + m_he_backend); + } else if (arg0_cipher != nullptr && out0_plain != nullptr) { + runtime::he::kernel::result(arg0_cipher->get_elements(), + out0_plain->get_elements(), output_size, + m_he_backend); + } else if (arg0_plain != nullptr && out0_plain != nullptr) { + runtime::he::kernel::result(arg0_plain->get_elements(), + out0_plain->get_elements(), output_size); + } else { + throw ngraph_error("Result types not supported."); + } + break; + } + case OP_TYPEID::Reverse: { + const op::Reverse* reverse = static_cast(&node); + Shape in_shape = node.get_input_shape(0); + Shape out_shape = node.get_output_shape(0); + + if (arg0_cipher != nullptr && out0_cipher != nullptr) { + runtime::he::kernel::reverse(arg0_cipher->get_elements(), + out0_cipher->get_elements(), in_shape, + out_shape, reverse->get_reversed_axes()); + } else if (arg0_plain != nullptr && out0_plain != nullptr) { + runtime::he::kernel::reverse(arg0_plain->get_elements(), + out0_plain->get_elements(), in_shape, + out_shape, reverse->get_reversed_axes()); + } else { + throw ngraph_error("Reverse types not supported."); + } + break; + } + case OP_TYPEID::ScalarConstantLike: + break; + case OP_TYPEID::Slice: { + const op::Slice* slice = static_cast(&node); + Shape in_shape = node.get_input_shape(0); + Shape out_shape = node.get_output_shape(0); + + if (arg0_cipher != nullptr && out0_cipher != nullptr) { + runtime::he::kernel::slice( + arg0_cipher->get_elements(), out0_cipher->get_elements(), in_shape, + slice->get_lower_bounds(), slice->get_upper_bounds(), + slice->get_strides(), out_shape); + } else if (arg0_plain != nullptr && out0_plain != nullptr) { + runtime::he::kernel::slice( + arg0_plain->get_elements(), out0_plain->get_elements(), in_shape, + slice->get_lower_bounds(), slice->get_upper_bounds(), + slice->get_strides(), out_shape); + } else { + throw ngraph_error("Slice types not supported."); + } + break; + } + case OP_TYPEID::Subtract: { + if (arg0_cipher != nullptr && arg1_cipher != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::subtract( + arg0_cipher->get_elements(), arg1_cipher->get_elements(), + out0_cipher->get_elements(), type, m_he_backend, + out0_cipher->get_batched_element_count()); + } else if (arg0_cipher != nullptr && arg1_plain != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::subtract( + arg0_cipher->get_elements(), arg1_plain->get_elements(), + out0_cipher->get_elements(), type, m_he_backend, + out0_cipher->get_batched_element_count()); + } else if (arg0_plain != nullptr && arg1_cipher != nullptr && + out0_cipher != nullptr) { + runtime::he::kernel::subtract( + arg0_plain->get_elements(), arg1_cipher->get_elements(), + out0_cipher->get_elements(), type, m_he_backend, + out0_cipher->get_batched_element_count()); + } else if (arg0_plain != nullptr && arg1_plain != nullptr && + out0_plain != nullptr) { + runtime::he::kernel::subtract( + arg0_plain->get_elements(), arg1_plain->get_elements(), + out0_plain->get_elements(), type, 
m_he_backend,
+            out0_plain->get_batched_element_count());
+      } else {
+        throw ngraph_error("Subtract types not supported.");
+      }
+      break;
+    }
+    case OP_TYPEID::Sum: {
+      const op::Sum* sum = static_cast<const op::Sum*>(&node);
+      Shape in_shape = node.get_input_shape(0);
+      Shape out_shape = node.get_output_shape(0);
+
+      if (arg0_cipher != nullptr && out0_cipher != nullptr) {
+        runtime::he::kernel::sum(
+            arg0_cipher->get_elements(), out0_cipher->get_elements(), in_shape,
+            out_shape, sum->get_reduction_axes(), type, m_he_backend);
+      } else if (arg0_plain != nullptr && out0_plain != nullptr) {
+        runtime::he::kernel::sum(
+            arg0_plain->get_elements(), out0_plain->get_elements(), in_shape,
+            out_shape, sum->get_reduction_axes(), type, m_he_backend);
+      } else {
+        throw ngraph_error("Sum types not supported.");
+      }
+      break;
+    }
+    // Unsupported ops
+    case OP_TYPEID::Abs:
+    case OP_TYPEID::Acos:
+    case OP_TYPEID::All:
+    case OP_TYPEID::AllReduce:
+    case OP_TYPEID::And:
+    case OP_TYPEID::Any:
+    case OP_TYPEID::ArgMax:
+    case OP_TYPEID::ArgMin:
+    case OP_TYPEID::Asin:
+    case OP_TYPEID::Atan:
+    case OP_TYPEID::AvgPoolBackprop:
+    case OP_TYPEID::BatchNormInference:
+    case OP_TYPEID::BatchNormTraining:
+    case OP_TYPEID::BatchNormTrainingBackprop:
+    case OP_TYPEID::Ceiling:
+    case OP_TYPEID::Convert:
+    case OP_TYPEID::ConvolutionBackpropData:
+    case OP_TYPEID::ConvolutionBackpropFilters:
+    case OP_TYPEID::Cos:
+    case OP_TYPEID::Cosh:
+    case OP_TYPEID::Dequantize:
+    case OP_TYPEID::Divide:
+    case OP_TYPEID::EmbeddingLookup:
+    case OP_TYPEID::Equal:
+    case OP_TYPEID::Exp:
+    case OP_TYPEID::Floor:
+    case OP_TYPEID::GenerateMask:
+    case OP_TYPEID::GetOutputElement:
+    case OP_TYPEID::Greater:
+    case OP_TYPEID::GreaterEq:
+    case OP_TYPEID::Less:
+    case OP_TYPEID::LessEq:
+    case OP_TYPEID::Log:
+    case OP_TYPEID::LRN:
+    case OP_TYPEID::Max:
+    case OP_TYPEID::Maximum:
+    case OP_TYPEID::MaxPool:
+    case OP_TYPEID::MaxPoolBackprop:
+    case OP_TYPEID::Min:
+    case OP_TYPEID::Minimum:
+    case OP_TYPEID::Not:
+    case OP_TYPEID::NotEqual:
+    case OP_TYPEID::OneHot:
+    case OP_TYPEID::Or:
+    case OP_TYPEID::Power:
+    case OP_TYPEID::Product:
+    case OP_TYPEID::Quantize:
+    case OP_TYPEID::QuantizedAvgPool:
+    case OP_TYPEID::QuantizedConvolutionBias:
+    case OP_TYPEID::QuantizedConvolutionBiasAdd:
+    case OP_TYPEID::QuantizedConvolutionBiasSignedAdd:
+    case OP_TYPEID::QuantizedConvolutionRelu:
+    case OP_TYPEID::QuantizedConvolution:
+    case OP_TYPEID::QuantizedMaxPool:
+    case OP_TYPEID::Relu:
+    case OP_TYPEID::ReluBackprop:
+    case OP_TYPEID::ReplaceSlice:
+    case OP_TYPEID::ReverseSequence:
+    case OP_TYPEID::Select:
+    case OP_TYPEID::ShapeOf:
+    case OP_TYPEID::Sigmoid:
+    case OP_TYPEID::SigmoidBackprop:
+    case OP_TYPEID::Sign:
+    case OP_TYPEID::Sin:
+    case OP_TYPEID::Sinh:
+    case OP_TYPEID::Softmax:
+    case OP_TYPEID::Sqrt:
+    case OP_TYPEID::StopGradient:
+    case OP_TYPEID::Tan:
+    case OP_TYPEID::Tanh:
+    case OP_TYPEID::TopK:
+    default:
+      throw unsupported_op("Unsupported op '" + node.description() + "'");
+#pragma GCC diagnostic pop
+  }
+}
\ No newline at end of file
diff --git a/src/he_executable.hpp b/src/he_executable.hpp
new file mode 100644
index 00000000..55dc621b
--- /dev/null
+++ b/src/he_executable.hpp
@@ -0,0 +1,64 @@
+//*****************************************************************************
+// Copyright 2018-2019 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "he_backend.hpp"
+#include "he_tensor.hpp"
+#include "ngraph/runtime/backend.hpp"
+#include "ngraph/util.hpp"
+#include "node_wrapper.hpp"
+
+namespace ngraph {
+namespace runtime {
+namespace he {
+
+class HEExecutable : public Executable {
+ public:
+  HEExecutable(const std::shared_ptr<Function>& function,
+               bool enable_performance_collection,
+               const runtime::he::HEBackend* he_backend, bool encrypt_data,
+               bool encrypt_model, bool batch_data);
+
+  bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
+            const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override;
+
+  void he_validate(
+      const std::vector<std::shared_ptr<runtime::he::HETensor>>& outputs,
+      const std::vector<std::shared_ptr<runtime::he::HETensor>>& inputs);
+
+  std::vector<PerformanceCounter> get_performance_data() const override;
+
+ private:
+  bool m_encrypt_data;
+  bool m_batch_data;
+  bool m_encrypt_model;
+  bool m_is_compiled = false;
+  const HEBackend* m_he_backend = nullptr;  // TODO: replace with context
+  std::unordered_map<const Node*, stopwatch> m_timer_map;
+  std::vector<NodeWrapper> m_wrapped_nodes;
+
+  void generate_calls(const element::Type& type, const NodeWrapper& op,
+                      const std::vector<std::shared_ptr<HETensor>>& outputs,
+                      const std::vector<std::shared_ptr<HETensor>>& inputs);
+};
+
+}  // namespace he
+}  // namespace runtime
+}  // namespace ngraph
diff --git a/src/he_plain_tensor.cpp b/src/he_plain_tensor.cpp
index 23494f8a..0b7b05be 100644
--- a/src/he_plain_tensor.cpp
+++ b/src/he_plain_tensor.cpp
@@ -28,10 +28,10 @@ runtime::he::HEPlainTensor::HEPlainTensor(
     const bool batched, const string& name)
     : runtime::he::HETensor(element_type, shape, he_backend, batched, name) {
   m_num_elements = m_descriptor->get_tensor_layout()->get_size();
-  m_plain_texts.resize(m_num_elements);
+  m_plaintexts.resize(m_num_elements);
 #pragma omp parallel for
   for (size_t i = 0; i < m_num_elements; ++i) {
-    m_plain_texts[i] = he_backend->create_empty_plaintext();
+    m_plaintexts[i] = he_backend->create_empty_plaintext();
   }
 }
 
@@ -53,7 +53,7 @@ void runtime::he::HEPlainTensor::write(const void* source, size_t tensor_offset,
   if (num_elements_to_write == 1) {
     const void* src_with_offset = (void*)((char*)source);
     size_t dst_index = dst_start_index;
-    m_he_backend->encode(m_plain_texts[dst_index], src_with_offset,
+    m_he_backend->encode(m_plaintexts[dst_index], src_with_offset,
                          element_type);
   } else {
 #pragma omp parallel for
@@ -74,12 +74,12 @@ void runtime::he::HEPlainTensor::write(const void* source, size_t tensor_offset,
                        type_byte_size * (i + j * num_elements_to_write));
           memcpy(destination, src, type_byte_size);
         }
-        m_he_backend->encode(m_plain_texts[dst_index], batch_src, element_type,
+        m_he_backend->encode(m_plaintexts[dst_index], batch_src, element_type,
                              m_batch_size);
 
         free((void*)batch_src);
       } else {
-        m_he_backend->encode(m_plain_texts[dst_index], src_with_offset,
+        m_he_backend->encode(m_plaintexts[dst_index], src_with_offset,
                              element_type);
       }
     }
@@ -103,7 +103,7 @@ void runtime::he::HEPlainTensor::read(void* target, size_t tensor_offset,
   if (num_elements_to_read == 1) {
     void* dst_with_offset = (void*)((char*)target);
     size_t src_index = src_start_index;
-    m_he_backend->decode(dst_with_offset,
m_plain_texts[src_index].get(), + m_he_backend->decode(dst_with_offset, m_plaintexts[src_index].get(), element_type, m_batch_size); } else { #pragma omp parallel for @@ -113,7 +113,7 @@ void runtime::he::HEPlainTensor::read(void* target, size_t tensor_offset, throw ngraph_error("Error allocating HE Cipher Tensor memory"); } size_t src_index = src_start_index + i; - m_he_backend->decode(dst, m_plain_texts[src_index].get(), element_type, + m_he_backend->decode(dst, m_plaintexts[src_index].get(), element_type, m_batch_size); for (size_t j = 0; j < m_batch_size; ++j) { diff --git a/src/he_plain_tensor.hpp b/src/he_plain_tensor.hpp index 09be1afa..f89e5062 100644 --- a/src/he_plain_tensor.hpp +++ b/src/he_plain_tensor.hpp @@ -55,15 +55,15 @@ class HEPlainTensor : public HETensor { inline std::vector>& get_elements() noexcept { - return m_plain_texts; + return m_plaintexts; } inline std::shared_ptr& get_element(size_t i) { - return m_plain_texts[i]; + return m_plaintexts[i]; } private: - std::vector> m_plain_texts; + std::vector> m_plaintexts; size_t m_num_elements; }; } // namespace he diff --git a/test/test_add.in.cpp b/test/test_add.in.cpp index 38090c33..a658271a 100644 --- a/test/test_add.in.cpp +++ b/test/test_add.in.cpp @@ -51,7 +51,8 @@ NGRAPH_TEST(${BACKEND_NAME}, add_2_3) { test::NDArray({{1, 2, 3}, {4, 5, 6}}).get_vector()); copy_data(t_b, test::NDArray({{7, 8, 9}, {10, 11, 12}}).get_vector()); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close( read_vector(t_result), (test::NDArray({{8, 10, 12}, {14, 16, 18}})).get_vector(), @@ -83,7 +84,8 @@ NGRAPH_TEST(${BACKEND_NAME}, add_zero_2_3) { test::NDArray({{1, 2, 3}, {4, 5, 6}}).get_vector()); copy_data(t_b, test::NDArray({{0, 0, 0}, {0, 0, 0}}).get_vector()); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close( read_vector(t_result), (test::NDArray({{1, 2, 3}, {4, 5, 6}})).get_vector(), 1e-3f)); @@ -111,7 +113,8 @@ NGRAPH_TEST(${BACKEND_NAME}, add_4_3_batch_cipher) { copy_data(t_a, vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); copy_data(t_b, vector{13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE( all_close((vector{14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36}), generalized_read_vector(t_result), 1e-3f)); @@ -138,7 +141,8 @@ NGRAPH_TEST(${BACKEND_NAME}, add_4_3_batch_plain) { copy_data(t_a, vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); copy_data(t_b, vector{13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE( all_close((vector{14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36}), generalized_read_vector(t_result), 1e-3f)); @@ -170,7 +174,8 @@ NGRAPH_TEST(${BACKEND_NAME}, add_optimized_2_3) { test::NDArray({{1, 2, 3}, {4, 5, 6}}).get_vector()); copy_data(t_b, test::NDArray({{-1, 0, 1}, {-1, 0, 1}}).get_vector()); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close( read_vector(t_result), (test::NDArray({{0, 2, 4}, {3, 5, 7}})).get_vector(), 1e-3f)); diff --git a/test/test_avg_pool.in.cpp b/test/test_avg_pool.in.cpp 
index 1f1ffffd..30a160a1 100644 --- a/test/test_avg_pool.in.cpp +++ b/test/test_avg_pool.in.cpp @@ -54,7 +54,8 @@ NGRAPH_TEST(${BACKEND_NAME}, avg_pool_1d_1channel_1image) { a, test::NDArray{{{0, 1, 0, 2, 1, 0, 3, 2, 0, 0, 2, 0, 0, 0}}} .get_vector()); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close(test::NDArray( {{{1 / denom, 3 / denom, 3 / denom, 3 / denom, 4 / denom, 5 / denom, 5 / denom, 2 / denom, @@ -92,7 +93,8 @@ NGRAPH_TEST(${BACKEND_NAME}, avg_pool_1d_1channel_2image) { {{0, 2, 1, 1, 0, 0, 0, 2, 0, 1, 0, 0, 1, 2}}}) .get_vector()); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close(test::NDArray( {{{1 / denom, 3 / denom, 3 / denom, 3 / denom, 4 / denom, 5 / denom, 5 / denom, 2 / denom, @@ -128,7 +130,8 @@ NGRAPH_TEST(${BACKEND_NAME}, avg_pool_1d_1channel_2image_batched) { {{0, 2, 1, 1, 0, 0, 0, 2, 0, 1, 0, 0, 1, 2}}}) .get_vector()); - backend->call(backend->compile(f), {t_result}, {t_a}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a}); EXPECT_TRUE(all_close( test::NDArray( {{{1 / denom, 3 / denom, 3 / denom, 3 / denom, 4 / denom, 5 / denom, @@ -171,7 +174,8 @@ NGRAPH_TEST(${BACKEND_NAME}, avg_pool_1d_2channel_2image) { {2, 1, 0, 0, 1, 0, 2, 0, 0, 0, 1, 1, 2, 0}}}) .get_vector()); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close(test::NDArray( {{{1 / denom, 3 / denom, 3 / denom, 3 / denom, 4 / denom, 5 / denom, 5 / denom, 2 / denom, @@ -239,7 +243,8 @@ NGRAPH_TEST(${BACKEND_NAME}, avg_pool_2d_2channel_2image) { {1, 0, 0, 0, 2}}}}) .get_vector()); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE( all_close(test::NDArray( {{{{6 / denom, 8 / denom, 5 / denom}, // img 0 chan 0 @@ -300,7 +305,8 @@ NGRAPH_TEST(${BACKEND_NAME}, avg_pool_2d_1channel_1image_strided) { {1, 0, 2, 0, 0, 0, 1, 0}}}}) .get_vector()); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close( test::NDArray({{{{6 / denom, 5 / denom, 4 / denom}, {6 / denom, 5 / denom, 8 / denom}, diff --git a/test/test_basics.in.cpp b/test/test_basics.in.cpp index 65aa825a..6be4989b 100644 --- a/test/test_basics.in.cpp +++ b/test/test_basics.in.cpp @@ -64,7 +64,9 @@ NGRAPH_TEST(${BACKEND_NAME}, validate_call_input_count) { auto b = backend->create_tensor(element::f32, shape); auto c = backend->create_tensor(element::f32, shape); - EXPECT_ANY_THROW(backend->call_with_validate(f, {c}, {a})); + auto handle = backend->compile(f); + + EXPECT_ANY_THROW(handle->call_with_validate({c}, {a})); } NGRAPH_TEST(${BACKEND_NAME}, validate_call_input_type) { @@ -81,7 +83,9 @@ NGRAPH_TEST(${BACKEND_NAME}, validate_call_input_type) { auto b = backend->create_tensor(element::f32, shape); auto c = backend->create_tensor(element::f32, shape); - EXPECT_ANY_THROW(backend->call_with_validate(f, {c}, {a, b})); + auto handle = backend->compile(f); + + EXPECT_ANY_THROW(handle->call_with_validate({c}, {a, b})); } NGRAPH_TEST(${BACKEND_NAME}, validate_call_input_shape) { @@ -98,7 +102,9 @@ NGRAPH_TEST(${BACKEND_NAME}, validate_call_input_shape) { auto b = backend->create_tensor(element::f32, shape); auto c = backend->create_tensor(element::f32, shape); - 
EXPECT_ANY_THROW(backend->call_with_validate(f, {c}, {a, b})); + auto handle = backend->compile(f); + + EXPECT_ANY_THROW(handle->call_with_validate({c}, {a, b})); } NGRAPH_TEST(${BACKEND_NAME}, validate_call_output_count) { @@ -116,7 +122,9 @@ NGRAPH_TEST(${BACKEND_NAME}, validate_call_output_count) { auto c = backend->create_tensor(element::f32, shape); auto d = backend->create_tensor(element::f32, shape); - EXPECT_ANY_THROW(backend->call_with_validate(f, {c, d}, {a, b})); + auto handle = backend->compile(f); + + EXPECT_ANY_THROW(handle->call_with_validate({c, d}, {a, b})); } NGRAPH_TEST(${BACKEND_NAME}, validate_call_output_type) { @@ -133,7 +141,9 @@ NGRAPH_TEST(${BACKEND_NAME}, validate_call_output_type) { auto b = backend->create_tensor(element::f32, shape); auto c = backend->create_tensor(element::f32, shape); - EXPECT_ANY_THROW(backend->call_with_validate(f, {a}, {b, c})); + auto handle = backend->compile(f); + + EXPECT_ANY_THROW(handle->call_with_validate({a}, {b, c})); } NGRAPH_TEST(${BACKEND_NAME}, validate_call_output_shape) { @@ -150,5 +160,7 @@ NGRAPH_TEST(${BACKEND_NAME}, validate_call_output_shape) { auto b = backend->create_tensor(element::f32, shape); auto c = backend->create_tensor(element::f32, shape); - EXPECT_ANY_THROW(backend->call_with_validate(f, {a}, {c, b})); + auto handle = backend->compile(f); + + EXPECT_ANY_THROW(handle->call_with_validate({a}, {c, b})); } diff --git a/test/test_broadcast.in.cpp b/test/test_broadcast.in.cpp index b2df5088..812f37e2 100644 --- a/test/test_broadcast.in.cpp +++ b/test/test_broadcast.in.cpp @@ -48,7 +48,8 @@ NGRAPH_TEST(${BACKEND_NAME}, broadcast_vector) { copy_data(a, vector{6}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE( all_close((vector{6, 6, 6, 6}), read_vector(result))); } @@ -85,7 +86,8 @@ NGRAPH_TEST(${BACKEND_NAME}, broadcast_matrix) { copy_data(a, vector{6}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE( all_close((vector{6, 6, 6, 6}), read_vector(result))); } @@ -111,7 +113,8 @@ NGRAPH_TEST(${BACKEND_NAME}, broadcast_tensor) { copy_data(a, vector{6}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((vector{6, 6, 6, 6, 6, 6, 6, 6}), read_vector(result))); } @@ -136,7 +139,8 @@ NGRAPH_TEST(${BACKEND_NAME}, broadcast_trivial) { copy_data(a, vector{2, 4, 6, 8, 16, 32, 64, 128}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((vector{2, 4, 6, 8, 16, 32, 64, 128}), read_vector(result))); } @@ -163,7 +167,8 @@ NGRAPH_TEST(${BACKEND_NAME}, broadcast_vector_colwise) { copy_data(a, vector{1, 2, 3}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((vector{1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}), read_vector(result))); } @@ -190,7 +195,8 @@ NGRAPH_TEST(${BACKEND_NAME}, broadcast_vector_rowwise) { copy_data(a, vector{1, 2, 3, 4}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((vector{1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4}), read_vector(result))); } @@ -216,7 +222,8 @@ NGRAPH_TEST(${BACKEND_NAME}, broadcast_matrix_0) { copy_data(a, vector{1, 2, 3, 4}); - backend->call(backend->compile(f), 
{result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((vector{1, 2, 3, 4, 1, 2, 3, 4}), read_vector(result))); } @@ -242,7 +249,8 @@ NGRAPH_TEST(${BACKEND_NAME}, broadcast_matrix_1) { copy_data(a, vector{1, 2, 3, 4}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((vector{1, 2, 1, 2, 3, 4, 3, 4}), read_vector(result))); } @@ -268,7 +276,8 @@ NGRAPH_TEST(${BACKEND_NAME}, broadcast_matrix_2) { copy_data(a, vector{1, 2, 3, 4}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((vector{1, 1, 2, 2, 3, 3, 4, 4}), read_vector(result))); } diff --git a/test/test_concat.in.cpp b/test/test_concat.in.cpp index 9ef00b00..fd8c127b 100644 --- a/test/test_concat.in.cpp +++ b/test/test_concat.in.cpp @@ -56,7 +56,8 @@ NGRAPH_TEST(${BACKEND_NAME}, concat_matrix_colwise) { copy_data(b, vector{1, 2, 4, 8, 16, 32}); copy_data(c, vector{2, 3, 5, 7, 11, 13}); - backend->call(backend->compile(f), {result}, {a, b, c}); + auto handle = backend->compile(f); + handle->call({result}, {a, b, c}); EXPECT_TRUE(all_close( vector{2, 4, 1, 2, 4, 2, 3, 5, 8, 16, 8, 16, 32, 7, 11, 13}, read_vector(result))); @@ -92,7 +93,8 @@ NGRAPH_TEST(${BACKEND_NAME}, concat_matrix_rowise) { copy_data(b, vector{1, 2, 4, 8, 16, 32}); copy_data(c, vector{2, 3, 5, 7, 11, 13}); - backend->call(backend->compile(f), {result}, {a, b, c}); + auto handle = backend->compile(f); + handle->call({result}, {a, b, c}); EXPECT_TRUE(all_close( vector{2, 4, 8, 16, 1, 2, 4, 8, 16, 32, 2, 3, 5, 7, 11, 13}, read_vector(result))); @@ -128,7 +130,8 @@ NGRAPH_TEST(${BACKEND_NAME}, concat_vector) { copy_data(b, vector{1, 2, 4, 8, 16, 32}); copy_data(c, vector{18, 19}); - backend->call(backend->compile(f), {result}, {a, b, c}); + auto handle = backend->compile(f); + handle->call({result}, {a, b, c}); EXPECT_TRUE( all_close(vector{2, 4, 8, 16, 1, 2, 4, 8, 16, 32, 18, 19}, read_vector(result))); @@ -161,7 +164,8 @@ NGRAPH_TEST(${BACKEND_NAME}, concat_4d_tensor) { copy_data(b, vector{2}); copy_data(c, vector{3}); - backend->call(backend->compile(f), {result}, {a, b, c}); + auto handle = backend->compile(f); + handle->call({result}, {a, b, c}); EXPECT_TRUE(all_close(vector{1, 2, 3}, read_vector(result))); } } @@ -192,7 +196,8 @@ NGRAPH_TEST(${BACKEND_NAME}, concat_2d_tensor) { copy_data(b, vector{2}); copy_data(c, vector{3}); - backend->call(backend->compile(f), {result}, {a, b, c}); + auto handle = backend->compile(f); + handle->call({result}, {a, b, c}); EXPECT_TRUE(all_close(vector{1, 2, 3}, read_vector(result))); } } diff --git a/test/test_constant.in.cpp b/test/test_constant.in.cpp index 9b723ae1..5a2a6adc 100644 --- a/test/test_constant.in.cpp +++ b/test/test_constant.in.cpp @@ -34,7 +34,8 @@ NGRAPH_TEST(${BACKEND_NAME}, constant) { auto f = make_shared(A, ParameterVector{}); auto result = backend->create_tensor(element::f32, shape); - backend->call(backend->compile(f), {result}, {}); + auto handle = backend->compile(f); + handle->call({result}, {}); EXPECT_TRUE(all_close((vector{0.1, 0.2, 0.3, 0.4}), read_vector(result))); } @@ -62,7 +63,8 @@ NGRAPH_TEST(${BACKEND_NAME}, constant_abc) { copy_data(b, test::NDArray({{5, 6}, {7, 8}}).get_vector()); copy_data(c, test::NDArray({{9, 10}, {11, 12}}).get_vector()); - backend->call(backend->compile(f), {result}, {b, c}); + auto handle = backend->compile(f); + handle->call({result}, 
{b, c}); EXPECT_TRUE(all_close( read_vector(result), diff --git a/test/test_convolution.in.cpp b/test/test_convolution.in.cpp index d26c06e1..9ea045b4 100644 --- a/test/test_convolution.in.cpp +++ b/test/test_convolution.in.cpp @@ -55,7 +55,8 @@ NGRAPH_TEST(${BACKEND_NAME}, convolution_2d_1image) { 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0}); copy_data(t_b, vector{0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close(read_vector(t_result), vector{9, 9, 9, 9, 9, 9, 9, 9, 9}, 1e-1f)); } @@ -84,7 +85,8 @@ NGRAPH_TEST(${BACKEND_NAME}, convolution_2d_1image_2outputs) { vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); copy_data(t_b, vector{1, 2, 3, 4, 5, 6, 7, 8}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close(read_vector(t_result), vector{51, 61, 71, 81, 101, 111, 121, 131, 115, 141, 167, 193, 245, 271, 297, 323}, @@ -123,7 +125,8 @@ NGRAPH_TEST(${BACKEND_NAME}, convolution_2d_1item) { copy_data(t_a, vector{-8.f, 2.f, -4.f, -2.f, 9.f, 9.f, -0.f, -3.f, -8.f, 5.f, -8.f, 1.f, 2.f, 8.f, -2.f}); copy_data(t_b, vector{-8.f, 2.f, -4.f, -2.f, 9.f, 9.f, -0.f, -3.f}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE( all_close(read_vector(t_result), vector{32.0f, -18.0f, 56.0f, 56.0f, -42.0f, -14.0f, @@ -164,7 +167,8 @@ NGRAPH_TEST(${BACKEND_NAME}, convolution_2d_1item_padded_1_1x1_1) { copy_data(t_a, vector{-8.f, 2.f, -4.f, -2.f, 9.f, 9.f, -0.f, -3.f, -8.f, 5.f, -8.f, 1.f, 2.f, 8.f, -2.f}); copy_data(t_b, vector{-8.f, 2.f, -4.f, -2.f, 9.f, 9.f, -0.f, -3.f}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close( read_vector(t_result), vector{16.0f, 28.0f, 0.0f, 20.0f, -10.0f, -36.0f, -34.0f, @@ -209,7 +213,8 @@ NGRAPH_TEST(${BACKEND_NAME}, convolution_2d_1item_padded_2_3x4_5) { copy_data(t_a, vector{-8.f, 2.f, -4.f, -2.f, 9.f, 9.f, -0.f, -3.f, -8.f, 5.f, -8.f, 1.f, 2.f, 8.f, -2.f}); copy_data(t_b, vector{-8.f, 2.f, -4.f, -2.f, 9.f, 9.f, -0.f, -3.f}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close( read_vector(t_result), vector{ @@ -274,7 +279,8 @@ NGRAPH_TEST(${BACKEND_NAME}, convolution_2d_2items) { 9.f, -7.f, 3.f, 0.f, 6.f, -1.f, -4.f, -2.f, 7.f, -0.f, -1.f, 7.f, -4.f, -9.f}); copy_data(t_b, vector{-8.f, 2.f, -4.f, -2.f, 9.f, 9.f, -0.f, -3.f}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close( read_vector(t_result), vector{32.0f, -18.0f, 56.0f, 56.0f, -42.0f, -14.0f, -16.0f, @@ -319,7 +325,8 @@ NGRAPH_TEST(${BACKEND_NAME}, convolution_2d_2items_strided_padded) { 9.f, -7.f, 3.f, 0.f, 6.f, -1.f, -4.f, -2.f, 7.f, -0.f, -1.f, 7.f, -4.f, -9.f}); copy_data(t_b, vector{-8.f, 2.f, -4.f, -2.f, 9.f, 9.f, -0.f, -3.f}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close( read_vector(t_result), vector{ diff --git a/test/test_cryptonets.cpp b/test/test_cryptonets.cpp index 1720e40d..7612d1c6 100644 --- 
a/test/test_cryptonets.cpp +++ b/test/test_cryptonets.cpp @@ -116,7 +116,8 @@ static void run_cryptonets_benchmark(string backend_name, NGRAPH_INFO << "calling function"; stopwatch sw_run_model; sw_run_model.start(); - backend->call(backend->compile(f), result_tvs, parameter_tvs); + auto handle = backend->compile(f); + handle->call(result_tvs, parameter_tvs); sw_run_model.stop(); NGRAPH_INFO << "sw_run_model: " << sw_run_model.get_milliseconds() << "ms"; diff --git a/test/test_dot.in.cpp b/test/test_dot.in.cpp index 37ccce66..c5419597 100644 --- a/test/test_dot.in.cpp +++ b/test/test_dot.in.cpp @@ -49,7 +49,8 @@ NGRAPH_TEST(${BACKEND_NAME}, dot1d) { copy_data(t_a, vector{1, 2, 3, 4}); copy_data(t_b, vector{5, 6, 7, 8}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE( all_close(read_vector(t_result), vector{70}, 1e-3f)); } @@ -81,7 +82,8 @@ NGRAPH_TEST(${BACKEND_NAME}, dot1d_optimized) { copy_data(t_a, vector{1, 2, 3, 4}); copy_data(t_b, vector{-1, 0, 1, 2}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE( all_close(read_vector(t_result), vector{10}, 1e-1f)); } @@ -112,7 +114,8 @@ NGRAPH_TEST(${BACKEND_NAME}, dot_matrix_vector) { copy_data(t_a, vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); copy_data(t_b, vector{17, 18, 19, 20}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close(read_vector(t_result), (vector{190, 486, 782, 1078}), 1e-3f)); } @@ -141,7 +144,8 @@ NGRAPH_TEST(${BACKEND_NAME}, dot_scalar) { copy_data(t_a, vector{8}); copy_data(t_b, vector{6}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE( all_close(read_vector(t_result), (vector{48}), 1e-3f)); } @@ -168,7 +172,8 @@ NGRAPH_TEST(${BACKEND_NAME}, dot_scalar_batch) { copy_data(t_a, vector{1, 2, 3}); copy_data(t_b, vector{4}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close((vector{4, 8, 12}), generalized_read_vector(t_result), 1e-3f)); } diff --git a/test/test_layers.in.cpp b/test/test_layers.in.cpp index e59bd5a3..c9e975ae 100644 --- a/test/test_layers.in.cpp +++ b/test/test_layers.in.cpp @@ -48,17 +48,20 @@ NGRAPH_TEST(${BACKEND_NAME}, mult_layer_cipher_cipher) { copy_data(b, test::NDArray({{5, 6}, {7, 8}}).get_vector()); copy_data(c, test::NDArray({{9, 10}, {11, 12}}).get_vector()); - backend->call(backend->compile(f), {result}, {a, b, c}); + auto handle1 = backend->compile(f); + handle1->call({result}, {a, b, c}); EXPECT_TRUE(all_close( read_vector(result), (test::NDArray({{45, 120}, {231, 384}})).get_vector(), 1e-1f)); - backend->call(backend->compile(f), {result}, {b, a, c}); + auto handle2 = backend->compile(f); + handle2->call({result}, {b, a, c}); EXPECT_TRUE(all_close( read_vector(result), (test::NDArray({{45, 120}, {231, 384}})).get_vector(), 1e-1f)); - backend->call(backend->compile(f), {result}, {c, a, b}); + auto handle3 = backend->compile(f); + handle3->call({result}, {c, a, b}); EXPECT_TRUE(all_close( read_vector(result), (test::NDArray({{45, 120}, {231, 384}})).get_vector(), 1e-1f)); @@ -85,17 +88,20 @@ NGRAPH_TEST(${BACKEND_NAME}, mult_layer_cipher_plain) { 
copy_data(b, test::NDArray({{5, 6}, {7, 8}}).get_vector()); copy_data(c, test::NDArray({{9, 10}, {11, 12}}).get_vector()); - backend->call(backend->compile(f), {result}, {a, b, c}); + auto handle1 = backend->compile(f); + handle1->call({result}, {a, b, c}); EXPECT_TRUE(all_close( read_vector(result), (test::NDArray({{45, 120}, {231, 384}})).get_vector(), 1e-1f)); - backend->call(backend->compile(f), {result}, {b, a, c}); + auto handle2 = backend->compile(f); + handle2->call({result}, {b, a, c}); EXPECT_TRUE(all_close( read_vector(result), (test::NDArray({{45, 120}, {231, 384}})).get_vector(), 1e-1f)); - backend->call(backend->compile(f), {result}, {c, a, b}); + auto handle3 = backend->compile(f); + handle3->call({result}, {c, a, b}); EXPECT_TRUE(all_close( read_vector(result), (test::NDArray({{45, 120}, {231, 384}})).get_vector(), 1e-1f)); @@ -122,17 +128,20 @@ NGRAPH_TEST(${BACKEND_NAME}, mult_layer_plain_plain) { copy_data(b, test::NDArray({{5, 6}, {7, 8}}).get_vector()); copy_data(c, test::NDArray({{9, 10}, {11, 12}}).get_vector()); - backend->call(backend->compile(f), {result}, {a, b, c}); + auto handle1 = backend->compile(f); + handle1->call({result}, {a, b, c}); EXPECT_TRUE(all_close( read_vector(result), (test::NDArray({{45, 120}, {231, 384}})).get_vector(), 1e-1f)); - backend->call(backend->compile(f), {result}, {b, a, c}); + auto handle2 = backend->compile(f); + handle2->call({result}, {b, a, c}); EXPECT_TRUE(all_close( read_vector(result), (test::NDArray({{45, 120}, {231, 384}})).get_vector(), 1e-1f)); - backend->call(backend->compile(f), {result}, {c, a, b}); + auto handle3 = backend->compile(f); + handle3->call({result}, {c, a, b}); EXPECT_TRUE(all_close( read_vector(result), (test::NDArray({{45, 120}, {231, 384}})).get_vector(), 1e-1f)); @@ -159,7 +168,8 @@ NGRAPH_TEST(${BACKEND_NAME}, add_layer_cipher_cipher) { copy_data(b, test::NDArray({{5, 6}, {7, 8}}).get_vector()); copy_data(c, test::NDArray({{9, 10}, {11, 12}}).get_vector()); - backend->call(backend->compile(f), {result}, {a, b, c}); + auto handle = backend->compile(f); + handle->call({result}, {a, b, c}); EXPECT_TRUE(all_close( read_vector(result), (test::NDArray({{14, 22}, {32, 44}})).get_vector(), 1e-1f)); @@ -186,7 +196,8 @@ NGRAPH_TEST(${BACKEND_NAME}, add_layer_cipher_plain) { copy_data(b, test::NDArray({{5, 6}, {7, 8}}).get_vector()); copy_data(c, test::NDArray({{9, 10}, {11, 12}}).get_vector()); - backend->call(backend->compile(f), {result}, {a, b, c}); + auto handle = backend->compile(f); + handle->call({result}, {a, b, c}); EXPECT_TRUE(all_close( read_vector(result), (test::NDArray({{14, 22}, {32, 44}})).get_vector(), 1e-1f)); @@ -213,7 +224,8 @@ NGRAPH_TEST(${BACKEND_NAME}, add_layer_plain_plain) { copy_data(b, test::NDArray({{5, 6}, {7, 8}}).get_vector()); copy_data(c, test::NDArray({{9, 10}, {11, 12}}).get_vector()); - backend->call(backend->compile(f), {result}, {a, b, c}); + auto handle = backend->compile(f); + handle->call({result}, {a, b, c}); EXPECT_TRUE(all_close( read_vector(result), (test::NDArray({{14, 22}, {32, 44}})).get_vector(), 1e-1f)); diff --git a/test/test_multiply.in.cpp b/test/test_multiply.in.cpp index 0c36d734..20f07d7b 100644 --- a/test/test_multiply.in.cpp +++ b/test/test_multiply.in.cpp @@ -51,7 +51,8 @@ NGRAPH_TEST(${BACKEND_NAME}, multiply_2_3) { test::NDArray({{1, 2, 3}, {4, 5, 6}}).get_vector()); copy_data(t_b, test::NDArray({{7, 8, 9}, {10, 11, 12}}).get_vector()); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + 
handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close( read_vector(t_result), (test::NDArray({{7, 16, 27}, {40, 55, 72}})).get_vector(), @@ -80,7 +81,8 @@ NGRAPH_TEST(${BACKEND_NAME}, square_2_3) { copy_data(t_a, test::NDArray({{1, 2, 3}, {4, 5, 6}}).get_vector()); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close( read_vector(t_result), (test::NDArray({{1, 4, 9}, {16, 25, 36}})).get_vector(), @@ -115,7 +117,8 @@ NGRAPH_TEST(${BACKEND_NAME}, multiply_optimized_2_3) { test::NDArray({{1, 2, 3}, {4, 5, 6}}).get_vector()); copy_data(t_b, test::NDArray({{-1, 0, 1}, {-1, 0, 1}}).get_vector()); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close( read_vector(t_result), (test::NDArray({{-1, 0, 3}, {-4, 0, 6}})).get_vector(), @@ -144,7 +147,8 @@ NGRAPH_TEST(${BACKEND_NAME}, multiply_4_3_batch) { copy_data(t_a, vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); copy_data(t_b, vector{13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}); - backend->call(backend->compile(f), {t_result}, {t_a, t_b}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a, t_b}); EXPECT_TRUE(all_close( (vector{13, 28, 45, 64, 85, 108, 133, 160, 189, 220, 253, 288}), generalized_read_vector(t_result), 1e-3f)); diff --git a/test/test_negate.in.cpp b/test/test_negate.in.cpp index 535c85ec..8e04e9f1 100644 --- a/test/test_negate.in.cpp +++ b/test/test_negate.in.cpp @@ -48,7 +48,8 @@ NGRAPH_TEST(${BACKEND_NAME}, negate_2_3) { copy_data(t_a, test::NDArray({{-3, -2, -1}, {0, 1, 2}}).get_vector()); - backend->call(backend->compile(f), {t_result}, {t_a}); + auto handle = backend->compile(f); + handle->call({t_result}, {t_a}); EXPECT_TRUE(all_close( read_vector(t_result), (test::NDArray({{3, 2, 1}, {0, -1, -2}})).get_vector())); diff --git a/test/test_pad.in.cpp b/test/test_pad.in.cpp index fa7b2e5e..dd45f19c 100644 --- a/test/test_pad.in.cpp +++ b/test/test_pad.in.cpp @@ -51,7 +51,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_interior_1d) { copy_data(b, vector{2112}); auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a, b}); + auto handle = backend->compile(f); + handle->call({result}, {a, b}); EXPECT_TRUE(all_close( (test::NDArray({1, 2112, 2112, 2, 2112, 2112, 3, 2112, 2112, 4, 2112, 2112, 5, 2112, 2112, 6}) @@ -83,7 +84,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_exterior_1d) { copy_data(b, vector{2112}); auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a, b}); + auto handle = backend->compile(f); + handle->call({result}, {a, b}); EXPECT_TRUE( all_close((test::NDArray({2112, 2112, 2112, 2112, 1, 2, 3, 4, 5, 6, 2112, 2112, 2112, 2112, 2112}) @@ -115,7 +117,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_interior_exterior_1d) { copy_data(b, vector{2112}); auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - he_backend->call(backend->compile(f), {result}, {a, b}); + auto handle = backend->compile(f); + handle->call({result}, {a, b}); EXPECT_TRUE(all_close((test::NDArray( {2112, 2112, 2112, 2112, 1, 2112, 2112, 2, 2112, 2112, 3, 2112, 2112, 4, 2112, 2112, 5, 2112, @@ -148,7 +151,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_interior_exterior_2d) { copy_data(b, vector{9}); auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), 
{result}, {a, b}); + auto handle = backend->compile(f); + handle->call({result}, {a, b}); EXPECT_TRUE(all_close((test::NDArray({{9, 9, 9, 9, 9, 9}, {1, 9, 2, 9, 3, 9}, {9, 9, 9, 9, 9, 9}, @@ -184,7 +188,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_exterior_2d_0x0) { copy_data(b, vector{2112}); auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a, b}); + auto handle = backend->compile(f); + handle->call({result}, {a, b}); EXPECT_TRUE( all_close((test::NDArray({{2112, 2112, 2112, 2112, 2112}, {2112, 2112, 2112, 2112, 2112}, @@ -219,7 +224,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_exterior_2d_0x3) { copy_data(b, vector{2112}); auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a, b}); + auto handle = backend->compile(f); + handle->call({result}, {a, b}); EXPECT_TRUE( all_close((test::NDArray({{2112, 2112, 2112, 2112, 2112}, {2112, 2112, 2112, 2112, 2112}, @@ -254,7 +260,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_exterior_2d_3x0) { copy_data(b, vector{2112}); auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a, b}); + auto handle = backend->compile(f); + handle->call({result}, {a, b}); EXPECT_TRUE( all_close((test::NDArray({{2112, 2112, 2112, 2112, 2112}, {2112, 2112, 2112, 2112, 2112}, @@ -305,7 +312,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_exterior_4d_1x2x2x2) { auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a, b}); + auto handle = backend->compile(f); + handle->call({result}, {a, b}); // clang-format off EXPECT_TRUE(all_close((test::NDArray( { @@ -359,7 +367,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_interior_exterior_4d_2x0x3x2) { vector expected(5 * 2 * 3 * 2, 2112); - backend->call(backend->compile(f), {result}, {a, b}); + auto handle = backend->compile(f); + handle->call({result}, {a, b}); EXPECT_TRUE(all_close(expected, read_vector(result), 1e-3f)); } @@ -385,7 +394,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_const_interior_1d) { copy_data(a, test::NDArray({1, 2, 3, 4, 5, 6}).get_vector()); auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE( all_close((test::NDArray({1, 123, 123, 2, 123, 123, 3, 123, 123, 4, 123, 123, 5, 123, 123, 6}) @@ -415,7 +425,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_const_exterior_1d) { copy_data(a, test::NDArray({1, 2, 3, 4, 5, 6}).get_vector()); auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE( all_close((test::NDArray({123, 123, 123, 123, 1, 2, 3, 4, 5, 6, 123, 123, 123, 123, 123}) @@ -445,7 +456,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_const_interior_exterior_1d) { copy_data(a, test::NDArray({1, 2, 3, 4, 5, 6}).get_vector()); auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close( (test::NDArray({123, 123, 123, 123, 1, 123, 123, 2, 123, 123, 3, 123, 123, 4, 123, 123, 5, 123, @@ -476,7 +488,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_const_interior_exterior_2d) { copy_data(a, test::NDArray({{1, 2, 3}, {4, 5, 6}}).get_vector()); auto result = 
he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE( all_close((test::NDArray({{123, 123, 123, 123, 123, 123}, {1, 123, 2, 123, 3, 123}, @@ -511,7 +524,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_const_exterior_2d_0x0) { // copy_data(a, test::NDArray({{}}).get_vector(), backend); auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((test::NDArray({{123, 123, 123, 123, 123}, {123, 123, 123, 123, 123}, {123, 123, 123, 123, 123}, @@ -543,7 +557,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_const_exterior_2d_0x3) { // copy_data(a, test::NDArray({}).get_vector(), backend); auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((test::NDArray({{123, 123, 123, 123, 123}, {123, 123, 123, 123, 123}, {123, 123, 123, 123, 123}, @@ -575,7 +590,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_const_exterior_2d_3x0) { // copy_data(a, test::NDArray({}).get_vector(), backend); auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((test::NDArray({{123, 123, 123, 123, 123}, {123, 123, 123, 123, 123}, {123, 123, 123, 123, 123}, @@ -622,7 +638,8 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_const_exterior_4d_1x2x2x2) { auto result = he_backend->create_cipher_tensor(element::f32, shape_r); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); // clang-format off EXPECT_TRUE(all_close((test::NDArray( { @@ -674,6 +691,7 @@ NGRAPH_TEST(${BACKEND_NAME}, pad_const_interior_exterior_4d_2x0x3x2) { vector expected(5 * 2 * 3 * 2, 123); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close(expected, read_vector(result), 1e-3f)); } diff --git a/test/test_reshape.in.cpp b/test/test_reshape.in.cpp index 7f90fe17..84d609b8 100644 --- a/test/test_reshape.in.cpp +++ b/test/test_reshape.in.cpp @@ -47,7 +47,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_t2v_012) { copy_data(a, vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE( all_close((vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}), read_vector(result), 1e-3f)); @@ -74,7 +75,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_t2s_012) { copy_data(a, vector{6}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE( all_close((vector{6}), read_vector(result), 1e-3f)); } @@ -100,7 +102,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_t2s_120) { copy_data(a, vector{6}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE( all_close((vector{6}), read_vector(result), 1e-3f)); } @@ -126,7 +129,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_s2t) { copy_data(a, vector{42}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); 
EXPECT_TRUE( all_close((vector{42}), read_vector(result), 1e-3f)); } @@ -151,7 +155,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_v2m_col) { auto result = results[0]; copy_data(a, vector{1, 2, 3}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE( all_close((vector{1, 2, 3}), read_vector(result), 1e-3f)); } @@ -176,7 +181,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_v2m_row) { auto result = results[0]; copy_data(a, vector{1, 2, 3}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE( all_close((vector{1, 2, 3}), read_vector(result), 1e-3f)); } @@ -201,7 +207,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_v2t_middle) { auto result = results[0]; copy_data(a, vector{1, 2, 3}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE( all_close((vector{1, 2, 3}), read_vector(result), 1e-3f)); } @@ -226,7 +233,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_m2m_same) { auto result = results[0]; copy_data(a, vector{1, 2, 3, 4, 5, 6, 7, 8, 9}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((vector{1, 2, 3, 4, 5, 6, 7, 8, 9}), read_vector(result), 1e-3f)); } @@ -251,7 +259,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_m2m_transpose) { auto result = results[0]; copy_data(a, vector{1, 2, 3, 4, 5, 6, 7, 8, 9}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((vector{1, 4, 7, 2, 5, 8, 3, 6, 9}), read_vector(result), 1e-3f)); } @@ -276,7 +285,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_m2m_dim_change_transpose) { auto result = results[0]; copy_data(a, vector{1, 2, 3, 4, 5, 6}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((vector{1, 3, 5, 2, 4, 6}), read_vector(result), 1e-3f)); } @@ -348,7 +358,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reshape_6d) { auto result = results[0]; copy_data(a, a_data); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close( (vector{ 1., 73., 9., 81., 17., 89., 2., 74., 10., 82., 18., diff --git a/test/test_reverse.in.cpp b/test/test_reverse.in.cpp index e83b1e1f..4469d1a0 100644 --- a/test/test_reverse.in.cpp +++ b/test/test_reverse.in.cpp @@ -46,7 +46,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_0d) { copy_data(a, vector{6}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE( all_close((vector{6}), read_vector(result), 1e-3f)); } @@ -71,7 +72,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_1d_nochange) { copy_data(a, vector{0, 1, 2, 3, 4, 5, 6, 7}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((vector{0, 1, 2, 3, 4, 5, 6, 7}), read_vector(result), 1e-3f)); } @@ -96,7 +98,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_1d_0) { copy_data(a, vector{0, 1, 2, 3, 4, 5, 6, 7}); - backend->call(backend->compile(f), {result}, {a}); + auto handle = backend->compile(f); + handle->call({result}, {a}); EXPECT_TRUE(all_close((vector{7, 6, 5, 4, 3, 2, 1, 0}), read_vector(result), 1e-3f)); } @@ -123,7 +126,8 @@ NGRAPH_TEST(${BACKEND_NAME}, 
reverse_2d_nochange) {
                 {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}})
                 .get_vector());
 
-  backend->call(backend->compile(f), {result}, {a});
+  auto handle = backend->compile(f);
+  handle->call({result}, {a});
   EXPECT_TRUE(all_close(
       (test::NDArray({{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}})
            .get_vector()),
@@ -152,7 +156,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_2d_0) {
                 {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}})
                 .get_vector());
 
-  backend->call(backend->compile(f), {result}, {a});
+  auto handle = backend->compile(f);
+  handle->call({result}, {a});
   EXPECT_TRUE(all_close(
       (test::NDArray({{9, 10, 11}, {6, 7, 8}, {3, 4, 5}, {0, 1, 2}})
            .get_vector()),
@@ -181,7 +186,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_2d_1) {
                 {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}})
                 .get_vector());
 
-  backend->call(backend->compile(f), {result}, {a});
+  auto handle = backend->compile(f);
+  handle->call({result}, {a});
   EXPECT_TRUE(all_close(
       (test::NDArray({{2, 1, 0}, {5, 4, 3}, {8, 7, 6}, {11, 10, 9}})
            .get_vector()),
@@ -210,7 +216,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_2d_01) {
                 {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}})
                 .get_vector());
 
-  backend->call(backend->compile(f), {result}, {a});
+  auto handle = backend->compile(f);
+  handle->call({result}, {a});
   EXPECT_TRUE(all_close(
       (test::NDArray({{11, 10, 9}, {8, 7, 6}, {5, 4, 3}, {2, 1, 0}})
            .get_vector()),
@@ -240,7 +247,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_3d_nochange) {
        {{12, 13, 14}, {15, 16, 17}, {18, 19, 20}, {21, 22, 23}}})
           .get_vector());
 
-  backend->call(backend->compile(f), {result}, {a});
+  auto handle = backend->compile(f);
+  handle->call({result}, {a});
   EXPECT_TRUE(all_close(
       (test::NDArray(
           {{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}},
@@ -272,7 +280,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_3d_0) {
        {{12, 13, 14}, {15, 16, 17}, {18, 19, 20}, {21, 22, 23}}})
           .get_vector());
 
-  backend->call(backend->compile(f), {result}, {a});
+  auto handle = backend->compile(f);
+  handle->call({result}, {a});
   EXPECT_TRUE(all_close(
       (test::NDArray(
          {{{12, 13, 14}, {15, 16, 17}, {18, 19, 20}, {21, 22, 23}},
@@ -304,7 +313,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_3d_1) {
        {{12, 13, 14}, {15, 16, 17}, {18, 19, 20}, {21, 22, 23}}})
           .get_vector());
 
-  backend->call(backend->compile(f), {result}, {a});
+  auto handle = backend->compile(f);
+  handle->call({result}, {a});
   EXPECT_TRUE(all_close(
       (test::NDArray(
           {{{9, 10, 11}, {6, 7, 8}, {3, 4, 5}, {0, 1, 2}},
@@ -336,7 +346,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_3d_2) {
        {{12, 13, 14}, {15, 16, 17}, {18, 19, 20}, {21, 22, 23}}})
           .get_vector());
 
-  backend->call(backend->compile(f), {result}, {a});
+  auto handle = backend->compile(f);
+  handle->call({result}, {a});
   EXPECT_TRUE(all_close(
       (test::NDArray(
           {{{2, 1, 0}, {5, 4, 3}, {8, 7, 6}, {11, 10, 9}},
@@ -368,7 +379,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_3d_01) {
        {{12, 13, 14}, {15, 16, 17}, {18, 19, 20}, {21, 22, 23}}})
           .get_vector());
 
-  backend->call(backend->compile(f), {result}, {a});
+  auto handle = backend->compile(f);
+  handle->call({result}, {a});
   EXPECT_TRUE(all_close(
       (test::NDArray(
          {{{21, 22, 23}, {18, 19, 20}, {15, 16, 17}, {12, 13, 14}},
@@ -400,7 +412,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_3d_02) {
        {{12, 13, 14}, {15, 16, 17}, {18, 19, 20}, {21, 22, 23}}})
           .get_vector());
 
-  backend->call(backend->compile(f), {result}, {a});
+  auto handle = backend->compile(f);
+  handle->call({result}, {a});
   EXPECT_TRUE(all_close(
       (test::NDArray(
          {{{14, 13, 12}, {17, 16, 15}, {20, 19, 18}, {23, 22, 21}},
@@ -432,7 +445,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_3d_12) {
        {{12, 13, 14}, {15, 16, 17}, {18, 19, 20}, {21, 22, 23}}})
           .get_vector());
 
-  backend->call(backend->compile(f), {result}, {a});
+  auto handle = backend->compile(f);
+  handle->call({result}, {a});
   EXPECT_TRUE(all_close(
       (test::NDArray(
           {{{11, 10, 9}, {8, 7, 6}, {5, 4, 3}, {2, 1, 0}},
@@ -464,7 +478,8 @@ NGRAPH_TEST(${BACKEND_NAME}, reverse_3d_012) {
        {{12, 13, 14}, {15, 16, 17}, {18, 19, 20}, {21, 22, 23}}})
           .get_vector());
 
-  backend->call(backend->compile(f), {result}, {a});
+  auto handle = backend->compile(f);
+  handle->call({result}, {a});
   EXPECT_TRUE(all_close(
       (test::NDArray(
          {{{23, 22, 21}, {20, 19, 18}, {17, 16, 15}, {14, 13, 12}},
diff --git a/test/test_slice.in.cpp b/test/test_slice.in.cpp
index ffe5a36f..d382245f 100644
--- a/test/test_slice.in.cpp
+++ b/test/test_slice.in.cpp
@@ -46,7 +46,8 @@ NGRAPH_TEST(${BACKEND_NAME}, slice_scalar) {
     auto result = results[0];
     copy_data(a, vector{312});
 
-    backend->call(backend->compile(f), {result}, {a});
+    auto handle = backend->compile(f);
+    handle->call({result}, {a});
     EXPECT_EQ((vector{312}), read_vector(result));
   }
 }
@@ -71,7 +72,8 @@ NGRAPH_TEST(${BACKEND_NAME}, slice_matrix) {
     copy_data(a,
              vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
 
-    backend->call(backend->compile(f), {result}, {a});
+    auto handle = backend->compile(f);
+    handle->call({result}, {a});
     EXPECT_TRUE(all_close((vector{2, 3, 6, 7, 10, 11}),
                           read_vector(result)));
   }
@@ -97,7 +99,8 @@ NGRAPH_TEST(${BACKEND_NAME}, slice_vector) {
     copy_data(
        a, vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
 
-    backend->call(backend->compile(f), {result}, {a});
+    auto handle = backend->compile(f);
+    handle->call({result}, {a});
     EXPECT_TRUE(
        all_close((vector{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}),
                  read_vector(result)));
@@ -125,7 +128,8 @@ NGRAPH_TEST(${BACKEND_NAME}, slice_matrix_strided) {
     copy_data(
        a, vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
 
-    backend->call(backend->compile(f), {result}, {a});
+    auto handle = backend->compile(f);
+    handle->call({result}, {a});
     EXPECT_TRUE(
        all_close((vector{4, 7, 12, 15}), read_vector(result)));
   }
@@ -155,7 +159,8 @@ NGRAPH_TEST(${BACKEND_NAME}, slice_3d) {
                      26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,
                      39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
                      52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63});
-    backend->call(backend->compile(f), {result}, {a});
+    auto handle = backend->compile(f);
+    handle->call({result}, {a});
     EXPECT_TRUE(all_close((vector{21, 22, 25, 26, 37, 38, 41, 42}),
                           read_vector(result)));
   }
@@ -186,7 +191,8 @@ NGRAPH_TEST(${BACKEND_NAME}, slice_3d_strided) {
                      27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
                      40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52,
                      53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64});
-    backend->call(backend->compile(f), {result}, {a});
+    auto handle = backend->compile(f);
+    handle->call({result}, {a});
     EXPECT_TRUE(all_close((vector{1, 3, 9, 11, 33, 35, 41, 43}),
                           read_vector(result)));
   }
@@ -218,7 +224,8 @@ NGRAPH_TEST(${BACKEND_NAME}, slice_3d_strided_different_strides) {
                      27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
                      40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52,
                      53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64});
-    backend->call(backend->compile(f), {result}, {a});
+    auto handle = backend->compile(f);
+    handle->call({result}, {a});
     EXPECT_TRUE(all_close((vector{1, 4, 9, 12, 33, 36, 41, 44}),
                           read_vector(result)));
   }
diff --git a/test/test_subtract.in.cpp b/test/test_subtract.in.cpp
index c3dc3c63..8b44efee 100644
--- a/test/test_subtract.in.cpp
+++ b/test/test_subtract.in.cpp
@@ -51,7 +51,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sub_2_3) {
             test::NDArray({{1, 2, 3}, {10, 11, 12}}).get_vector());
   copy_data(t_b,
             test::NDArray({{7, 8, 9}, {4, 5, 6}}).get_vector());
-  backend->call(backend->compile(f), {t_result}, {t_a, t_b});
+  auto handle = backend->compile(f);
+  handle->call({t_result}, {t_a, t_b});
   EXPECT_TRUE(all_close(
       read_vector(t_result),
       (test::NDArray({{-6, -6, -6}, {6, 6, 6}})).get_vector(),
@@ -85,7 +86,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sub_zero_2_3) {
             test::NDArray({{1, 2, 3}, {4, 5, 6}}).get_vector());
   copy_data(t_b,
             test::NDArray({{0, 0, 0}, {0, 0, 0}}).get_vector());
-  backend->call(backend->compile(f), {t_result}, {t_a, t_b});
+  auto handle = backend->compile(f);
+  handle->call({t_result}, {t_a, t_b});
   EXPECT_TRUE(all_close(
       read_vector(t_result),
       (test::NDArray({{1, 2, 3}, {4, 5, 6}})).get_vector(), 1e-3f));
@@ -118,7 +120,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sub_from_zero_2_3) {
             test::NDArray({{0, 0, 0}, {0, 0, 0}}).get_vector());
   copy_data(t_b,
             test::NDArray({{1, 2, 3}, {-1, -2, -3}}).get_vector());
-  backend->call(backend->compile(f), {t_result}, {t_a, t_b});
+  auto handle = backend->compile(f);
+  handle->call({t_result}, {t_a, t_b});
   EXPECT_TRUE(all_close(
       read_vector(t_result),
       (test::NDArray({{-1, -2, -3}, {1, 2, 3}})).get_vector(),
diff --git a/test/test_sum.in.cpp b/test/test_sum.in.cpp
index 4a60cdc0..467c4abd 100644
--- a/test/test_sum.in.cpp
+++ b/test/test_sum.in.cpp
@@ -48,7 +48,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_trivial) {
   auto t_result = results[0];
   copy_data(t_a, vector{1, 2, 3, 4});
 
-  backend->call(backend->compile(f), {t_result}, {t_a});
+  auto handle = backend->compile(f);
+  handle->call({t_result}, {t_a});
   EXPECT_TRUE(
      all_close((vector{1, 2, 3, 4}), read_vector(t_result)));
 }
@@ -77,7 +78,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_trivial_5d) {
   copy_data(t_a, vector{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
 
-  backend->call(backend->compile(f), {t_result}, {t_a});
+  auto handle = backend->compile(f);
+  handle->call({t_result}, {t_a});
   EXPECT_TRUE(all_close(
       (vector{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
              1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}),
@@ -106,7 +108,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_to_scalar) {
   auto t_result = results[0];
   copy_data(t_a, vector{1, 2, 3, 4});
 
-  backend->call(backend->compile(f), {t_result}, {t_a});
+  auto handle = backend->compile(f);
+  handle->call({t_result}, {t_a});
   EXPECT_TRUE(
      all_close((vector{10}), read_vector(t_result), 1e-3f));
 
@@ -138,7 +141,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_matrix_columns) {
   auto t_result = results[0];
   copy_data(t_a, vector{1, 2, 3, 4, 5, 6});
 
-  backend->call(backend->compile(f), {t_result}, {t_a});
+  auto handle = backend->compile(f);
+  handle->call({t_result}, {t_a});
   EXPECT_TRUE(
      all_close((vector{9, 12}), read_vector(t_result), 1e-3f));
 
@@ -170,7 +174,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_matrix_rows) {
   auto t_result = results[0];
   copy_data(t_a, vector{1, 2, 3, 4, 5, 6});
 
-  backend->call(backend->compile(f), {t_result}, {t_a});
+  auto handle = backend->compile(f);
+  handle->call({t_result}, {t_a});
   EXPECT_TRUE(
      all_close((vector{3, 7, 11}), read_vector(t_result)));
 
@@ -203,7 +208,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_matrix_rows_zero) {
   copy_data(t_a, vector{});
   copy_data(t_result, vector({3, 3, 3}));
 
-  backend->call(backend->compile(f), {t_result}, {t_a});
+  auto handle = backend->compile(f);
+  handle->call({t_result}, {t_a});
   EXPECT_TRUE(
      all_close((vector{0, 0, 0}), read_vector(t_result)));
 
@@ -235,7 +241,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_matrix_cols_zero) {
   copy_data(t_a, vector{});
   copy_data(t_result, vector({3, 3}));
 
-  backend->call(backend->compile(f), {t_result}, {t_a});
+  auto handle = backend->compile(f);
+  handle->call({t_result}, {t_a});
   EXPECT_TRUE(all_close((vector{0, 0}), read_vector(t_result)));
 
   // For some reason I'm feeling extra paranoid about making sure reduction
@@ -266,7 +273,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_matrix_vector_zero) {
   copy_data(t_a, vector{});
   copy_data(t_result, vector({3}));
 
-  backend->call(backend->compile(f), {t_result}, {t_a});
+  auto handle = backend->compile(f);
+  handle->call({t_result}, {t_a});
   EXPECT_TRUE(all_close((vector{0}), read_vector(t_result)));
 
   // For some reason I'm feeling extra paranoid about making sure reduction
@@ -297,7 +305,8 @@ NGRAPH_TEST(${BACKEND_NAME}, sum_matrix_to_scalar_zero_by_zero) {
   copy_data(t_a, vector{});
   copy_data(t_result, vector({3}));
 
-  backend->call(backend->compile(f), {t_result}, {t_a});
+  auto handle = backend->compile(f);
+  handle->call({t_result}, {t_a});
   EXPECT_TRUE(all_close((vector{0}), read_vector(t_result)));
 
   // For some reason I'm feeling extra paranoid about making sure reduction