From 8bcdc780aa3a50b55b8442acda6c3ceab170dcdd Mon Sep 17 00:00:00 2001 From: Alexander Mishunin Date: Wed, 9 Mar 2022 23:31:12 +0300 Subject: [PATCH] Add test case for tf.reduce_sum(..., axis = ...) --- .../GradientTest/GradientTest.cs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs index 8dac1131d..12ad58e15 100644 --- a/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs +++ b/test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs @@ -175,6 +175,24 @@ void test(string name, Func tfF, Func new[] { -1.0, 1.0 }); } + [TestMethod] + public void testReduceSumGradients() + { + var x = tf.placeholder(tf.float64, shape: new Shape(1, 1)); + var m = tf.broadcast_to(x, new Shape(2, 3)); + var g0 = tf.gradients(tf.reduce_sum(m), x)[0]; + var g1 = tf.gradients(tf.reduce_sum(m, axis: 0), x)[0]; + var g2 = tf.gradients(tf.reduce_sum(m, axis: 1), x)[0]; + + using (var session = tf.Session()) + { + var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, 1.0)); + self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)"); + self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)"); + self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)"); + } + } + [TestMethod] public void testTanhGradient() {