Rebase updates
earhart committed Oct 28, 2018
1 parent c8eba67 commit 2b2bfe7
Showing 2 changed files with 105 additions and 60 deletions.
152 changes: 93 additions & 59 deletions src/ngraph/runtime/plaidml/plaidml_ops_batch_norm.cpp
@@ -18,43 +18,88 @@
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/runtime/plaidml/plaidml_impl.hpp"

// BatchNorm implements batch normalization.
// BatchNormInference implements batch normalization for inference, in
// which the mean and variance to use are supplied.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNorm>::operator()()
void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::operator()()
{
// There are two variations of BatchNorm: we produce mean and variance iff they're not supplied.
auto& input_shape = op().get_input_shape(2);
bool given_mean_and_variance = op().get_input_size() == 5;

if (given_mean_and_variance)
{
check_inputs(5);
check_outputs(1);
}
else
{
check_inputs(3);
check_outputs(3);
}
check_inputs(5);
check_outputs(1);
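// Inference-mode batch norm takes five inputs (Gamma, Beta, Input, and the
// precomputed Mean and Variance) and produces a single output, the
// normalized tensor.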

auto f = start_tile_function();
f.add(builder::Input{op_input(0), "Gamma"}.add_dims({"C"}))
.add(builder::Input{op_input(1), "Beta"}.add_dims({"C"}))
.add(builder::Input{op_input(2), "Input"}
.add_dims({"B", "C"})
.add_dims("DI", 3, input_shape.size() + 1))
.add(builder::Output{"Normalized"});
.add(builder::Output{"Normalized"})
.add(builder::Input{op_input(3), "Mean"}.add_dims({"C"}))
.add(builder::Input{op_input(4), "Variance"}.add_dims({"C"}));

std::string ones;
for (auto idx = 2; idx < input_shape.size(); ++idx)
{
ones += ", 1";
}
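// At this point `ones` holds one ", 1" per spatial dimension; it is appended
// to the reshape() calls below (e.g. reshape(Gamma, C, 1, 1)) so that the
// per-channel tensors broadcast against the full-rank Input.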

if (input_shape.size() <= 2)
{
f.add(builder::Elementwise{"GammaP", "Gamma"}).add(builder::Elementwise{"BetaP", "Beta"});
}
else
{
f.add(builder::Elementwise{"GammaP", std::string{"reshape(Gamma, C"} + ones + ")"})
.add(builder::Elementwise{"BetaP", std::string{"reshape(Beta, C"} + ones + ")"});
}

if (input_shape.size() <= 2)
{
f.add(builder::Elementwise{"MeanP", "Mean"});
}
else
{
f.add(builder::Elementwise{"MeanP", std::string{"reshape(Mean, C"} + ones + ")"});
}

if (given_mean_and_variance)
if (input_shape.size() <= 2)
{
f.add(builder::Input{op_input(3), "Mean"}.add_dims({"C"}))
.add(builder::Input{op_input(4), "Variance"}.add_dims({"C"}));
f.add(builder::Elementwise{"VarianceP", "Variance"});
}
else
{
f.add(builder::Output{"Mean"}).add(builder::Output{"Variance"});
f.add(builder::Elementwise{"VarianceP", std::string{"reshape(Variance, C"} + ones + ")"});
}

f.add(builder::Elementwise{"Normalized",
"(((Input-MeanP) / sqrt(VarianceP + " +
std::to_string(op().get_eps_value()) + ")) * GammaP) + BetaP"});
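// In effect: Normalized = ((Input - Mean) / sqrt(Variance + eps)) * Gamma + Beta,
// computed per channel C and broadcast over the batch and spatial dimensions.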

auto app = f.finalize();

set_output(app);
}

// BatchNormTraining implements batch normalization for training, in
// which the mean and variance are to be computed from the supplied
// input.
template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::operator()()
{
auto& input_shape = op().get_input_shape(2);
check_inputs(3);
check_outputs(3);
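// Training-mode batch norm takes three inputs (Gamma, Beta, Input) and
// produces three outputs: the normalized tensor plus the per-channel Mean
// and Variance computed from Input.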

auto f = start_tile_function();
f.add(builder::Input{op_input(0), "Gamma"}.add_dims({"C"}))
.add(builder::Input{op_input(1), "Beta"}.add_dims({"C"}))
.add(builder::Input{op_input(2), "Input"}
.add_dims({"B", "C"})
.add_dims("DI", 3, input_shape.size() + 1))
.add(builder::Output{"Normalized"})
.add(builder::Output{"Mean"})
.add(builder::Output{"Variance"});

std::string ones;
for (auto idx = 2; idx < input_shape.size(); ++idx)
{
@@ -68,32 +113,30 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNorm>::operator()()
else
{
f.add(builder::Elementwise{"GammaP", std::string{"reshape(Gamma, C"} + ones + ")"})
.add(builder::Elementwise{"BetaP", std::string{"reshape(Beta, C"} + ones + ")"});
.add(builder::Elementwise{"BetaP", std::string{"reshape(Beta, C"} + ones + ")"});
}

if (!given_mean_and_variance)
{
if (input_shape.size() <= 2)
{
f.add(builder::Elementwise{"EltCount", "B"});
}
else
{
std::string elts{"B"};
for (auto idx = 2; idx < input_shape.size(); ++idx)
{
elts += " * DI" + std::to_string(idx + 1);
}
f.add(builder::Elementwise{"EltCount", std::move(elts)});
}
if (input_shape.size() <= 2)
{
f.add(builder::Elementwise{"EltCount", "B"});
}
else
{
std::string elts{"B"};
for (auto idx = 2; idx < input_shape.size(); ++idx)
{
elts += " * DI" + std::to_string(idx + 1);
}
f.add(builder::Elementwise{"EltCount", std::move(elts)});
}
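// EltCount is the number of elements contributing to each channel's
// statistics: the batch size times the product of the spatial dimensions.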

f.add(builder::UnaryContraction{"+"}
f.add(builder::UnaryContraction{"+"}
.set(builder::ContractionOutput{"SumInput"}.add_indices({"c"}).add_dims({"C"}))
.set(builder::ContractionInput{"Input"}
.add_indices({"b", "c"})
.add_indices("di", 3, input_shape.size() + 1)));
f.add(builder::Elementwise{"Mean", "SumInput / EltCount"});
}
f.add(builder::Elementwise{"Mean", "SumInput / EltCount"});


if (input_shape.size() <= 2)
{
@@ -104,18 +147,15 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNorm>::operator()()
f.add(builder::Elementwise{"MeanP", std::string{"reshape(Mean, C"} + ones + ")"});
}

if (!given_mean_and_variance)
{
f.add(builder::Elementwise{"DiffV", "(Input - MeanP)"})
.add(builder::Elementwise{"SqDiffV", "DiffV*DiffV"})
.add(builder::UnaryContraction{"+"}
f.add(builder::Elementwise{"DiffV", "(Input - MeanP)"})
.add(builder::Elementwise{"SqDiffV", "DiffV*DiffV"})
.add(builder::UnaryContraction{"+"}
.set(builder::ContractionOutput{"SumSqDiffV"}.add_indices({"c"}).add_dims(
{"C"}))
.set(builder::ContractionInput{"SqDiffV"}
.add_indices({"b", "c"})
.add_indices("di", 3, input_shape.size() + 1)))
.add(builder::Elementwise{"Variance", "SumSqDiffV / EltCount"});
}

if (input_shape.size() <= 2)
{
Expand All @@ -132,20 +172,13 @@ void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNorm>::operator()()

auto app = f.finalize();

if (given_mean_and_variance)
{
set_output(app);
}
else
{
set_output(0, app.get_output(0));
set_output(1, app.get_output(1));
set_output(2, app.get_output(2));
}
set_output(0, app.get_output(0));
set_output(1, app.get_output(1));
set_output(2, app.get_output(2));
}

template <>
void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormBackprop>::operator()()
void ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::operator()()
{
// WARNING: I'm unconvinced that we have sufficient test coverage for BatchNorm
// backprop and in particular I'm concerned that Gamma/Beta and Mean/Var could be
@@ -268,7 +301,8 @@

namespace
{
ngraph::runtime::plaidml::Impl<ngraph::op::BatchNorm>::Registration register_batch_norm;
ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormBackprop>::Registration
register_batch_norm_backprop;
ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormInference>::Registration register_batch_norm_inference;
ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTraining>::Registration register_batch_norm_training;
ngraph::runtime::plaidml::Impl<ngraph::op::BatchNormTrainingBackprop>::Registration
register_batch_norm_training_backprop;
}
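For reference, each of the Tile functions above implements the standard batch-normalization formula. The following is a minimal scalar sketch of the inference-mode computation; it is illustrative only and not part of this commit. The function name, NCHW layout, and float element type are assumptions made for the example, and eps plays the role of op().get_eps_value().

#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative reference only: y = gamma * (x - mean) / sqrt(var + eps) + beta,
// applied per channel c of an NCHW tensor. Hypothetical helper, not backend code.
std::vector<float> batch_norm_inference_reference(const std::vector<float>& x,
                                                  const std::vector<float>& gamma,
                                                  const std::vector<float>& beta,
                                                  const std::vector<float>& mean,
                                                  const std::vector<float>& variance,
                                                  std::size_t N, std::size_t C,
                                                  std::size_t H, std::size_t W,
                                                  float eps)
{
    std::vector<float> y(x.size());
    for (std::size_t n = 0; n < N; ++n)
    {
        for (std::size_t c = 0; c < C; ++c)
        {
            // Fold gamma and the variance term into a single per-channel scale.
            float scale = gamma[c] / std::sqrt(variance[c] + eps);
            for (std::size_t h = 0; h < H; ++h)
            {
                for (std::size_t w = 0; w < W; ++w)
                {
                    std::size_t i = ((n * C + c) * H + h) * W + w;
                    y[i] = scale * (x[i] - mean[c]) + beta[c];
                }
            }
        }
    }
    return y;
}

The training-mode implementation differs only in that Mean and Variance are first reduced from Input itself (per channel, over the batch and spatial dimensions) and are returned as two additional outputs.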
13 changes: 12 additions & 1 deletion src/ngraph/runtime/plaidml/unit_test.manifest
@@ -42,6 +42,8 @@ topk_2d_min_one # No plans to implement TopK
# Tests that PlaidML might be able to run at some point.
backwards_maxpool_n2_c1_hw5_3x3_str2_max_pad1x2_2x3
backwards_slice
batchnorm_fprop_bprop # To debug
batchnorm_fprop_bprop_2step # To debug
reduce_matrix_rows_zero # To debug: possible broadcasting error?
reduce_matrix_cols_zero # To debug: possible broadcasting error?
reduce_3d_to_vector # To debug: possible broadcasting error?
@@ -58,6 +60,15 @@ select_and_scatter_without_overlap
select_and_scatter_3d_without_overlap
avg_pool_3d
avg_pool_3d_uneven_strided_padded_include_in_computation
dequantize_zero_offset # Quantization/Dequantization is unimplemented
quantize_ROUND_NEAREST_TOWARD_ZERO # Quantization/Dequantization is unimplemented
quantize_ROUND_NEAREST_UPWARD # Quantization/Dequantization is unimplemented
quantize_ROUND_NEAREST_DOWNWARD # Quantization/Dequantization is unimplemented
quantize_ROUND_NEAREST_TOWARD_EVEN # Quantization/Dequantization is unimplemented
quantize_ROUND_TOWARD_INFINITY # Quantization/Dequantization is unimplemented
quantize_ROUND_TOWARD_ZERO # Quantization/Dequantization is unimplemented
quantize_ROUND_UP # Quantization/Dequantization is unimplemented
quantize_ROUND_DOWN # Quantization/Dequantization is unimplemented
quantize # Quantization/Dequantization is unimplemented
quantize_axes # Quantization/Dequantization is unimplemented
quantize_int8 # Quantization/Dequantization is unimplemented
@@ -72,4 +83,4 @@ sum_matrix_to_scalar_zero_by_zero # Empty dims apparently should produce shaped 0s
sum_3d_eliminate_zero_dim # Empty dims apparently should produce shaped 0s
dot_0_0 # Empty dims apparently should produce shaped 0s
dot_matrix_2x0_0x2 # Empty dims apparently should produce shaped 0s
dot_2x0_0 # Empty dims apparently should produce shaped 0s
dot_2x0_0 # Empty dims apparently should produce shaped 0s
