Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
Add support for more types in gather op. (#2926)
Browse files Browse the repository at this point in the history
* Add test for i32 gather

* Add support for ints to Gather op

* Move helper function to anonymous namespace

* Add more types

* Use static_cast instead of the old one

* Style fix

* Skip tests on GPU

* Add more tests

* Skip tests on gpu

* Change bool to char
  • Loading branch information
tsocha authored and rkimballn1 committed May 17, 2019
1 parent 3d28d06 commit 9d50951
Show file tree
Hide file tree
Showing 4 changed files with 360 additions and 97 deletions.
204 changes: 107 additions & 97 deletions src/ngraph/runtime/cpu/builder/gather.cpp
Expand Up @@ -29,116 +29,126 @@ namespace ngraph
{
namespace cpu
{
namespace
{
template <typename T>
CPUKernelFunctor prepare_functor(const Node* node,
const vector<TensorViewWrapper>& args,
const vector<TensorViewWrapper>& out,
CPU_ExternalFunction* external_function)
{
const ngraph::op::Gather* gather = static_cast<const ngraph::op::Gather*>(node);
auto params_buffer_index =
external_function->get_buffer_index(args[0].get_name());
auto indices_buffer_index =
external_function->get_buffer_index(args[1].get_name());
auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());

bool is_int64 = args[1].get_element_type() == element::i64;
auto axis = gather->get_axis();
auto params_shape = args[0].get_shape();
auto indices_shape = args[1].get_shape();
auto out_shape = out[0].get_shape();

if (is_int64)
{
return
[&,
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<T, int64_t>(
static_cast<T*>(ctx->buffer_data[params_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<T*>(ctx->buffer_data[out_buffer_index]),
params_shape,
indices_shape,
out_shape,
axis);
};
}
else
{
return
[&,
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<T, int32_t>(
static_cast<T*>(ctx->buffer_data[params_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<T*>(ctx->buffer_data[out_buffer_index]),
params_shape,
indices_shape,
out_shape,
axis);
};
}
}
} // namespace

template <>
void Builder::BUILDER_DECL(ngraph::op::Gather)
{
auto& functors = external_function->get_functors();
const ngraph::op::Gather* gather = static_cast<const ngraph::op::Gather*>(node);
CPUKernelFunctor functor;

auto params_buffer_index = external_function->get_buffer_index(args[0].get_name());
auto indices_buffer_index = external_function->get_buffer_index(args[1].get_name());
auto out_buffer_index = external_function->get_buffer_index(out[0].get_name());
if (args[1].get_element_type() != element::i64 &&
args[1].get_element_type() != element::i32)
{
throw ngraph_error("Unsupported index element type");
}
bool is_int64 = args[1].get_element_type() == element::i64;
auto axis = gather->get_axis();
auto params_shape = args[0].get_shape();
auto indices_shape = args[1].get_shape();
auto out_shape = out[0].get_shape();
auto element_type = args[0].get_element_type();
if (element_type == element::f32)
{
if (is_int64)
{
functor = [&,
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<float, int64_t>(
static_cast<float*>(ctx->buffer_data[params_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
params_shape,
indices_shape,
out_shape,
axis);
};
}
else
{
functor = [&,
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<float, int32_t>(
static_cast<float*>(ctx->buffer_data[params_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<float*>(ctx->buffer_data[out_buffer_index]),
params_shape,
indices_shape,
out_shape,
axis);
};
}
functor = prepare_functor<float>(node, args, out, external_function);
}
else if (element_type == element::f64)
{
if (is_int64)
{
functor = [&,
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<double, int64_t>(
static_cast<double*>(ctx->buffer_data[params_buffer_index]),
static_cast<int64_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
params_shape,
indices_shape,
out_shape,
axis);
};
}
else
{
functor = [&,
params_shape,
indices_shape,
out_shape,
axis,
params_buffer_index,
indices_buffer_index,
out_buffer_index](CPURuntimeContext* ctx,
CPUExecutionContext* ectx) {
ngraph::runtime::reference::gather<double, int32_t>(
static_cast<double*>(ctx->buffer_data[params_buffer_index]),
static_cast<int32_t*>(ctx->buffer_data[indices_buffer_index]),
static_cast<double*>(ctx->buffer_data[out_buffer_index]),
params_shape,
indices_shape,
out_shape,
axis);
};
}
functor = prepare_functor<double>(node, args, out, external_function);
}
else if (element_type == element::i8)
{
functor = prepare_functor<int8_t>(node, args, out, external_function);
}
else if (element_type == element::i16)
{
functor = prepare_functor<int16_t>(node, args, out, external_function);
}
else if (element_type == element::i32)
{
functor = prepare_functor<int32_t>(node, args, out, external_function);
}
else if (element_type == element::i64)
{
functor = prepare_functor<int64_t>(node, args, out, external_function);
}
else if (element_type == element::u8)
{
functor = prepare_functor<uint8_t>(node, args, out, external_function);
}
else if (element_type == element::u16)
{
functor = prepare_functor<uint16_t>(node, args, out, external_function);
}
else if (element_type == element::u32)
{
functor = prepare_functor<uint32_t>(node, args, out, external_function);
}
else if (element_type == element::u64)
{
functor = prepare_functor<uint64_t>(node, args, out, external_function);
}
else if (element_type == element::boolean)
{
functor = prepare_functor<char>(node, args, out, external_function);
}
else
{
Expand All @@ -149,6 +159,6 @@ namespace ngraph
}

REGISTER_OP_BUILDER(Gather);
}
}
}
} // namespace cpu
} // namespace runtime
} // namespace ngraph
9 changes: 9 additions & 0 deletions src/ngraph/runtime/gpu/unit_test.manifest
Expand Up @@ -165,3 +165,12 @@ scatter_add_1d_indices
scatter_add_scalar_indices
scatter_nd_add_batch_2d_to_3d
scatter_nd_add_2d_to_3d
gather_no_axis_int8
gather_no_axis_int16
gather_no_axis_int32
gather_no_axis_int64
gather_no_axis_uint8
gather_no_axis_uint16
gather_no_axis_uint32
gather_no_axis_uint64
gather_no_axis_bool
9 changes: 9 additions & 0 deletions src/ngraph/runtime/intelgpu/unit_test.manifest
Expand Up @@ -79,3 +79,12 @@ scatter_add_scalar_indices
scatter_nd_add_batch_2d_to_3d
scatter_nd_add_2d_to_3d
zero_sized_erf
gather_no_axis_int8
gather_no_axis_int16
gather_no_axis_int32
gather_no_axis_int64
gather_no_axis_uint8
gather_no_axis_uint16
gather_no_axis_uint32
gather_no_axis_uint64
gather_no_axis_bool

0 comments on commit 9d50951

Please sign in to comment.