Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/TensorFlowNET.Core/Gradients/math_grad.cs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ public static Tensor[] _ExpGrad(Operation op, Tensor[] grads)
[RegisterNoGradient("GreaterEqual")]
public static Tensor[] _GreaterEqualGrad(Operation op, Tensor[] grads) => null;

[RegisterNoGradient("OnesLike")]
public static Tensor[] _OnesLike(Operation op, Tensor[] grads) => null;

[RegisterNoGradient("ZerosLike")]
public static Tensor[] _ZerosLike(Operation op, Tensor[] grads) => null;

Expand Down
38 changes: 31 additions & 7 deletions src/TensorFlowNET.Core/Operations/array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ public static Tensor _autopacking_helper(IEnumerable<object> list_or_tuple, TF_D
{
if (elem is EagerTensor eager_tensor)
{
if(switch_to_graph)
if (switch_to_graph)
elems_as_tensors.Add(constant_op.constant(eager_tensor.numpy(), dtype: dtype, name: i.ToString()));
else
elems_as_tensors.Add(eager_tensor);
Expand Down Expand Up @@ -366,8 +366,30 @@ public static Tensor rank_internal(Tensor input, string name = null, bool optimi
/// <param name="name"></param>
/// <param name="optimize"></param>
/// <returns></returns>
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> ones_like_impl(tensor, dtype, name, optimize);
public static Tensor ones_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
{
return tf_with(ops.name_scope(name, "ones_like", new Tensor[] { tensor }), scope =>
{
name = scope;
tensor = ops.convert_to_tensor(tensor, name: "tensor");

// is_fully_defined return unexpected value.
if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT)
{

}

if (dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT)
{
throw new NotImplementedException("ones_like");
// return ones(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name);
}
else
{
return gen_array_ops.ones_like(tensor, name: name);
}
});
}

public static Tensor reshape(Tensor tensor, Tensor shape, string name = null)
=> gen_array_ops.reshape(tensor, shape, name: name);
Expand Down Expand Up @@ -888,7 +910,7 @@ public static Tensor transpose<T1>(T1 a, TensorShape perm, string name = "transp
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
var a_tensor = ops.convert_to_tensor(a);
if(perm == null)
if (perm == null)
{
var rank = a_tensor.rank;
perm = range(0, rank).OrderByDescending(x => x).ToArray();
Expand Down Expand Up @@ -950,7 +972,9 @@ public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name
=> tf.Context.RunInAutoMode2(
() => tf.OpDefLib._apply_op_helper("Slice", name, new
{
input, begin, size
input,
begin,
size
}).output,
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Slice", name,
Expand All @@ -966,8 +990,8 @@ public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name
tf.Runner.RecordGradient("Slice", op.inputs, attrs, op.outputs);
},
new Tensors(input, begin, size));
public static Tensor stack(object values, int axis = 0, string name = "stack")

public static Tensor stack(object values, int axis = 0, string name = "stack")
{
if (axis == 0)
// If the input is a constant list, it can be converted to a constant op
Expand Down
9 changes: 9 additions & 0 deletions src/TensorFlowNET.Core/Operations/gen_array_ops.cs
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,15 @@ public static Tensor transpose<T1>(Tensor x, T1 perm, string name = null)
return _op.outputs[0];
}

public static Tensor ones_like(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"OnesLike", name,
null,
x).FirstOrDefault(),
x);

public static Tensor zeros_like(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, ()
Expand Down