Skip to content

Commit

Permalink
Merge pull request #1202 from Wanglongzhi2001/master
Browse files Browse the repository at this point in the history
fix: add the implementation of the tile's and GatherND's grad and add OptionalArgs
  • Loading branch information
Oceania2018 committed Oct 20, 2023
2 parents e79ecb7 + d0ec659 commit 079b9a3
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 6 deletions.
10 changes: 10 additions & 0 deletions src/TensorFlowNET.Core/APIs/tf.array.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,16 @@ public Tensor identity(Tensor input, string name = null)
public Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
=> array_ops.gather(@params, indices, name: name, axis: ops.convert_to_tensor(axis));

/// <summary>
/// Gather slices from `params` into a Tensor with shape specified by `indices`.
/// </summary>
/// <param name="params"></param>
/// <param name="indices"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor gather_nd(Tensor @params, Tensor indices, string name = null)
=> gen_array_ops.gather_nd(@params, indices, name: name);

/// <summary>
/// Return the elements, either from `x` or `y`, depending on the `condition`.
/// </summary>
Expand Down
43 changes: 43 additions & 0 deletions src/TensorFlowNET.Core/Gradients/array_grad.cs
Original file line number Diff line number Diff line change
Expand Up @@ -381,5 +381,48 @@ public static Tensor[] _ReverseV2Grad(Operation op, Tensor[] grads)
var axis = op.inputs[1];
return new Tensor[] { array_ops.reverse(grad, axis), null };
}

[RegisterGradient("Tile")]
public static Tensor[] _TileGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
var input_shape = array_ops.shape(op.inputs[0], out_type: op.inputs[1].dtype);
var split_shape = array_ops.reshape(array_ops.transpose(array_ops.stack(new Tensor[] { op.inputs[1], input_shape })), new Shape(-1));
var axes = math_ops.range(0, array_ops.size(split_shape), 2);

//# Sum reduces grad along the first dimension for IndexedSlices
//if isinstance(grad, indexed_slices_lib.IndexedSlices):
//input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
//grad = math_ops.unsorted_segment_sum(
// grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
//split_shape = array_ops.concat([[1], split_shape[1:]], axis = 0)

var input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes);
if (!tf.Context.executing_eagerly())
{
input_grad.set_shape(op.inputs[0].GetShape());
}
return new Tensor[] { input_grad, null };
}

[RegisterGradient("GatherNd")]
public static Tensor[] _GatherNdGrad(Operation op, Tensor[] grads)
{
var @ref = op.inputs[0];
var indices = op.inputs[1];
var grad = grads[0];
var ref_shape = array_ops.shape(@ref, out_type: indices.dtype);
Tensor ref_grad = null;
if (indices.shape.ndim == 2 && indices.shape.dims[indices.shape.Length - 1] == 1)
{
ref_grad = (Tensor)new IndexedSlices(grad, array_ops.squeeze(indices, axis: -1), ref_shape);
}
else
{
ref_grad = gen_array_ops.scatter_nd(indices, grad, ref_shape);
}
return new Tensor[] { ref_grad, null };
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@

namespace Tensorflow.Keras.ArgsDefinition
{
public class GRUOptionalArgs
public class GRUOptionalArgs : RnnOptionalArgs
{
public string Identifier => "GRU";

public Tensor Mask { get; set; } = null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class LSTMOptionalArgs : RnnOptionalArgs
{
public string Identifier => "LSTM";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class SimpleRNNOptionalArgs : RnnOptionalArgs
{
public string Identifier => "SimpleRNN";
}
}
4 changes: 2 additions & 2 deletions src/TensorFlowNET.Core/Operations/array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ public static Tensor stop_gradient(Tensor input, string name = null)
/// <returns>A `Tensor`. Has the same type as `input`.
/// Contains the same data as `input`, but has one or more dimensions of
/// size 1 removed.</returns>
public static Tensor squeeze(Tensor input, int[] axis = null, string name = null)
public static Tensor squeeze(Tensor input, Axis axis = null, string name = null)
=> gen_array_ops.squeeze(input, axis, name);

public static Tensor identity(Tensor input, string name = null)
Expand Down Expand Up @@ -990,7 +990,7 @@ public static Tensor gather(ResourceVariable @params, Tensor indices, string nam
return @params.sparse_read(indices, name);
}

public static Tensor transpose<T1>(T1 a, Axis perm, string name = "transpose", bool conjugate = false)
public static Tensor transpose<T1>(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false)
{
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
Expand Down
31 changes: 30 additions & 1 deletion test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public void SquaredDifference_1D()
// Calcute the gradient of (x1-x2)^2
// by Automatic Differentiation in Eager mode
// Expected is 2*(abs(x1-x2))
Tensor x1 = new NDArray( new float[] { 1, 3, 5, 21, 19, 17 });
Tensor x1 = new NDArray(new float[] { 1, 3, 5, 21, 19, 17 });
Tensor x2 = new NDArray(new float[] { 29, 27, 23, 7, 11, 13 });
float[] expected = new float[]
{
Expand Down Expand Up @@ -173,5 +173,34 @@ public void ConditionalMultiply()
var result = grad(x, 4);
Assert.AreEqual((float)result, 4.0f);
}

[TestMethod]
public void Tile()
{
var a = tf.constant(new int[] { 1 }, TF_DataType.TF_FLOAT);
var b = tf.constant(new int[] { 2 });
using (var tape = tf.GradientTape())
{
tape.watch(a);
var y = tf.tile(a, b);
var grad = tape.gradient(y, a);
Assert.AreEqual((float)grad.numpy(), 2.0f);
}
}

[TestMethod]
public void GatherNdTest()
{
var x = tf.constant(new float[,] { { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f } }, dtype: TF_DataType.TF_FLOAT);
var indices = tf.constant(new int[,] { { 0, 1 }, { 1, 1 }, { 2, 1 } }, dtype: TF_DataType.TF_INT32);
using (var tape = tf.GradientTape())
{
tape.watch(x);
var res = tf.gather_nd(x, indices);
var grad = tape.gradient(res, x);
var expected = np.array(new float[,] { { 0f, 1f, 0f }, { 0f, 1f, 0f }, { 0f, 1f, 0f } });
Assert.IsTrue(Enumerable.SequenceEqual(grad.ToArray<float>(), expected.ToArray<float>()));
}
}
}
}

0 comments on commit 079b9a3

Please sign in to comment.