extend gradient function capability.
Oceania2018 committed Jun 7, 2019
1 parent 4ff993b commit f13e35d
Showing 8 changed files with 107 additions and 64 deletions.
13 changes: 13 additions & 0 deletions docs/source/Gradient.md
@@ -1,2 +1,15 @@
# Chapter. Gradient

### Register custom gradient function

TF.NET is extensible: a custom gradient function can be registered for any operation type.

```csharp
// define a gradient function for the ConcatV2 operation
ops.RegisterGradientFunction("ConcatV2", (oper, out_grads) =>
{
    var grad = out_grads[0];
    // compute and return one gradient tensor per input of the op
    return new Tensor[] { /* ... */ };
});
```
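
Once registered, the delegate is looked up by operation type when gradients are built, and registering again under the same name replaces the earlier delegate. A minimal usage sketch for a hypothetical single-input op type "MyIdentity" (the op name is illustrative; only `RegisterGradientFunction` comes from TF.NET):

```csharp
// Register a pass-through gradient for a hypothetical op type "MyIdentity".
// The delegate receives the forward Operation and the upstream gradients,
// and must return one gradient tensor per input of the op.
ops.RegisterGradientFunction("MyIdentity", (oper, out_grads) =>
{
    return new Tensor[] { out_grads[0] };
});
```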

16 changes: 16 additions & 0 deletions src/TensorFlowNET.Core/Gradients/RegisterGradient.cs
@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Gradients
{
public class RegisterGradient : Attribute
{
public string Name { get; set; }

public RegisterGradient(string name)
{
Name = name;
}
}
}
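
The attribute is used at two levels in the files that follow: on a class, it marks the type for assembly scanning; on a method, it names the operation type whose gradient the method computes. A minimal sketch of the pattern, using a hypothetical `example_grad` class and op type "MyOp":

```csharp
using Tensorflow;
using Tensorflow.Gradients;

// The class-level attribute flags the type for the reflection scan in
// ops.get_gradient_function; the method-level attribute names the
// operation type being differentiated.
[RegisterGradient("example_grad")]
public class example_grad
{
    [RegisterGradient("MyOp")]
    public static Tensor[] _MyOpGrad(Operation op, Tensor[] grads)
    {
        // A pass-through op's gradient returns the upstream gradient unchanged.
        return new Tensor[] { grads[0] };
    }
}
```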
6 changes: 5 additions & 1 deletion src/TensorFlowNET.Core/Gradients/array_grad.cs
@@ -10,8 +10,10 @@ namespace Tensorflow.Gradients
/// <summary>
/// tensorflow\python\ops\array_grad.py
/// </summary>
[RegisterGradient("array_grad")]
public class array_grad
{
[RegisterGradient("ConcatV2")]
public static Tensor[] _ConcatGradV2(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -123,12 +125,13 @@ private static Tensor[] _ExtractInputShapes(Tensor[] inputs)
return gen_ops.shape_n(inputs);
}


[RegisterGradient("Reshape")]
public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
{
return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
}

[RegisterGradient("Squeeze")]
public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
{
return new Tensor[] { _ReshapeToInput(op, grads[0]) };
@@ -139,6 +142,7 @@ private static Tensor _ReshapeToInput(Operation op, Tensor grad)
return array_ops.reshape(grad, array_ops.shape(op.inputs[0]));
}

[RegisterGradient("Transpose")]
public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads)
{
var p = op.inputs[1];
5 changes: 3 additions & 2 deletions src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs
@@ -69,11 +69,12 @@ public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads)
// false_grad = switch(grad[0], op.inputs[1])[0]
// true_grad = switch(grad[1], op.inputs[1])[1]
// return merge([false_grad, true_grad])[0], None
}

}

/// <summary>
/// Gradients for a Merge op are calculated using a Switch op.
/// </summary>
[RegisterGradient("Merge")]
public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
16 changes: 16 additions & 0 deletions src/TensorFlowNET.Core/Gradients/math_grad.cs
@@ -10,8 +10,10 @@ namespace Tensorflow.Gradients
/// <summary>
/// Gradients for operators defined in math_ops.py.
/// </summary>
[RegisterGradient("math_grad")]
public class math_grad
{
[RegisterGradient("Add")]
public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
{
var x = op.inputs[0];
@@ -32,6 +34,7 @@ public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
return new Tensor[] { r1, r2 };
}

[RegisterGradient("DivNoNan")]
public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -59,6 +62,7 @@ public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
/// <param name="op"></param>
/// <param name="grads"></param>
/// <returns></returns>
[RegisterGradient("Exp")]
public static Tensor[] _ExpGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -69,11 +73,13 @@ public static Tensor[] _ExpGrad(Operation op, Tensor[] grads)
});
}

[RegisterGradient("Identity")]
public static Tensor[] _IdGrad(Operation op, Tensor[] grads)
{
return new Tensor[] { grads[0] };
}

[RegisterGradient("Log")]
public static Tensor[] _LogGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -84,6 +90,7 @@ public static Tensor[] _LogGrad(Operation op, Tensor[] grads)
});
}

[RegisterGradient("Mul")]
public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
{
var x = op.inputs[0];
@@ -112,6 +119,7 @@ public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
return new Tensor[] { reshape1, reshape2 };
}

[RegisterGradient("MatMul")]
public static Tensor[] _MatMulGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -145,6 +153,7 @@ public static Tensor[] _MatMulGrad(Operation op, Tensor[] grads)
return new Tensor[] { grad_a, grad_b };
}

[RegisterGradient("Mean")]
public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -159,6 +168,7 @@ public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
return new Tensor[] { math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), null };
}

[RegisterGradient("Neg")]
public static Tensor[] _NegGrad(Operation op, Tensor[] grads)
{
return new Tensor[] { -grads[0] };
@@ -169,6 +179,7 @@ private static Tensor _safe_shape_div(Tensor x, Tensor y)
return math_ops.floordiv(x, gen_math_ops.maximum(y, 1));
}

[RegisterGradient("Sub")]
public static Tensor[] _SubGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -198,6 +209,7 @@ public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad
!x_shape.Contains(-1);
}

[RegisterGradient("Sum")]
public static Tensor[] _SumGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -231,6 +243,7 @@ public static Tensor[] _SumGrad(Operation op, Tensor[] grads)
return new Tensor[] { gen_array_ops.tile(grad, tile_scaling), null };
}

[RegisterGradient("RealDiv")]
public static Tensor[] _RealDivGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -254,6 +267,7 @@ public static Tensor[] _RealDivGrad(Operation op, Tensor[] grads)
return new Tensor[] { reshape2, reshape1 };
}

[RegisterGradient("Sigmoid")]
public static Tensor[] _SigmoidGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -266,6 +280,7 @@ public static Tensor[] _SigmoidGrad(Operation op, Tensor[] grads)
});
}

[RegisterGradient("Square")]
public static Tensor[] _SquareGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -279,6 +294,7 @@ public static Tensor[] _SquareGrad(Operation op, Tensor[] grads)
});
}

[RegisterGradient("Pow")]
public static Tensor[] _PowGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
7 changes: 7 additions & 0 deletions src/TensorFlowNET.Core/Gradients/nn_grad.cs
@@ -9,6 +9,7 @@ namespace Tensorflow.Gradients
/// <summary>
///
/// </summary>
[RegisterGradient("math_grad")]
public class nn_grad
{
/// <summary>
@@ -17,6 +18,7 @@ public class nn_grad
/// <param name="op"></param>
/// <param name="grad"></param>
/// <returns></returns>
[RegisterGradient("BiasAdd")]
public static Tensor[] _BiasAddGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
@@ -25,6 +27,7 @@ public static Tensor[] _BiasAddGrad(Operation op, Tensor[] grads)
return new Tensor[] { grad, bias_add_grad };
}

[RegisterGradient("Relu")]
public static Tensor[] _ReluGrad(Operation op, Tensor[] grads)
{
return new Tensor[] { gen_nn_ops.relu_grad(grads[0], op.outputs[0]) };
@@ -36,6 +39,7 @@ public static Tensor[] _ReluGrad(Operation op, Tensor[] grads)
/// <param name="op"></param>
/// <param name="grads"></param>
/// <returns></returns>
[RegisterGradient("Softmax")]
public static Tensor[] _SoftmaxGrad(Operation op, Tensor[] grads)
{
var grad_softmax = grads[0];
@@ -54,6 +58,7 @@ public static Tensor[] _SoftmaxGrad(Operation op, Tensor[] grads)
/// <param name="grad_loss"></param>
/// <param name="grad_grad"></param>
/// <returns></returns>
[RegisterGradient("SoftmaxCrossEntropyWithLogits")]
public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
{
var grad_loss = grads[0];
@@ -74,6 +79,7 @@ public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[]
};
}

[RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")]
public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
{
var sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
@@ -111,6 +117,7 @@ private static Tensor _BroadcastMul(Tensor vec, Tensor mat)
/// <param name="op"></param>
/// <param name="grads"></param>
/// <returns></returns>
[RegisterGradient("TopK")]
public static Tensor[] _TopKGrad(Operation op, Tensor[] grads)
{
var grad = grads[0];
105 changes: 45 additions & 60 deletions src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
@@ -1,80 +1,65 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using Tensorflow.Gradients;

namespace Tensorflow
{
public partial class ops
{
static Dictionary<string, Func<Operation, Tensor[], Tensor[]>> gradientFunctions = null;

/// <summary>
/// Register a new gradient function
/// </summary>
/// <param name="name">operation type</param>
/// <param name="func">function delegate</param>
public static void RegisterGradientFunction(string name, Func<Operation, Tensor[], Tensor[]> func)
{
if(gradientFunctions == null)
gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();

gradientFunctions[name] = func;
}

public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op)
{
if (op.inputs == null) return null;

-            // map tensorflow\python\ops\math_grad.py
-            return (oper, out_grads) =>
+            if (gradientFunctions == null)
             {
-                // Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'");
+                gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();

-                switch (oper.type)
+                var gradGroups = Assembly.GetExecutingAssembly()
+                    .GetTypes()
+                    .Where(x => x.GetCustomAttribute<RegisterGradient>() != null)
+                    .ToArray();
+
+                foreach (var g in gradGroups)
                 {
-                    case "Add":
-                        return math_grad._AddGrad(oper, out_grads);
-                    case "BiasAdd":
-                        return nn_grad._BiasAddGrad(oper, out_grads);
-                    case "ConcatV2":
-                        return array_grad._ConcatGradV2(oper, out_grads);
-                    case "DivNoNan":
-                        return math_grad._DivNoNanGrad(oper, out_grads);
-                    case "Exp":
-                        return math_grad._ExpGrad(oper, out_grads);
-                    case "Identity":
-                        return math_grad._IdGrad(oper, out_grads);
-                    case "Log":
-                        return math_grad._LogGrad(oper, out_grads);
-                    case "MatMul":
-                        return math_grad._MatMulGrad(oper, out_grads);
-                    case "Merge":
-                        return control_flow_grad._MergeGrad(oper, out_grads);
-                    case "Mul":
-                        return math_grad._MulGrad(oper, out_grads);
-                    case "Mean":
-                        return math_grad._MeanGrad(oper, out_grads);
-                    case "Neg":
-                        return math_grad._NegGrad(oper, out_grads);
-                    case "Sum":
-                        return math_grad._SumGrad(oper, out_grads);
-                    case "Sub":
-                        return math_grad._SubGrad(oper, out_grads);
-                    case "Pow":
-                        return math_grad._PowGrad(oper, out_grads);
-                    case "RealDiv":
-                        return math_grad._RealDivGrad(oper, out_grads);
-                    case "Reshape":
-                        return array_grad._ReshapeGrad(oper, out_grads);
-                    case "Relu":
-                        return nn_grad._ReluGrad(oper, out_grads);
-                    case "Sigmoid":
-                        return math_grad._SigmoidGrad(oper, out_grads);
-                    case "Square":
-                        return math_grad._SquareGrad(oper, out_grads);
-                    case "Squeeze":
-                        return array_grad._SqueezeGrad(oper, out_grads);
-                    case "Softmax":
-                        return nn_grad._SoftmaxGrad(oper, out_grads);
-                    case "SoftmaxCrossEntropyWithLogits":
-                        return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
-                    case "SparseSoftmaxCrossEntropyWithLogits":
-                        return nn_grad._SparseSoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
-                    case "Transpose":
-                        return array_grad._TransposeGrad(oper, out_grads);
-                    case "TopK":
-                    case "TopKV2":
-                        return nn_grad._TopKGrad(oper, out_grads);
-                    default:
-                        throw new NotImplementedException($"get_gradient_function {oper.type}");
+                    var methods = g.GetMethods().Where(x => x.GetCustomAttribute<RegisterGradient>() != null)
+                        .ToArray();
+
+                    foreach (var m in methods)
+                    {
+                        RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name,
+                            (oper, out_grads) =>
+                                g.InvokeMember(m.Name,
+                                    BindingFlags.InvokeMethod,
+                                    null,
+                                    null,
+                                    args: new object[] { oper, out_grads }) as Tensor[]
+                        );
+                    }
                 }
-            };
}

if (!gradientFunctions.ContainsKey(op.type))
throw new NotImplementedException($"can't get graident function through get_gradient_function {op.type}");

return gradientFunctions[op.type];
}
}
}
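
On first call, `get_gradient_function` scans the executing assembly for types and methods carrying `RegisterGradient` and caches one delegate per operation type; subsequent calls are plain dictionary lookups. Below is a standalone sketch of that discovery step (the `GradientDiscoverySketch` name is illustrative, and binding via `CreateDelegate` is an assumption of this sketch, not the commit's `Type.InvokeMember` approach):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using Tensorflow;
using Tensorflow.Gradients;

static class GradientDiscoverySketch
{
    // Build an op-type -> delegate index, mirroring the cache that
    // get_gradient_function populates on first use.
    public static Dictionary<string, Func<Operation, Tensor[], Tensor[]>> Discover()
    {
        var map = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();

        var gradTypes = Assembly.GetExecutingAssembly()
            .GetTypes()
            .Where(t => t.GetCustomAttribute<RegisterGradient>() != null);

        foreach (var type in gradTypes)
        {
            var methods = type.GetMethods()
                .Where(m => m.GetCustomAttribute<RegisterGradient>() != null);

            foreach (var method in methods)
            {
                var opType = method.GetCustomAttribute<RegisterGradient>().Name;
                // Bind the static method once, instead of reflecting on
                // every gradient call via Type.InvokeMember.
                map[opType] = (Func<Operation, Tensor[], Tensor[]>)
                    method.CreateDelegate(typeof(Func<Operation, Tensor[], Tensor[]>));
            }
        }

        return map;
    }
}
```

Binding the method to a delegate up front would also surface signature mismatches at scan time rather than at gradient time.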
3 changes: 2 additions & 1 deletion src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -20,7 +20,8 @@ Docs: https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.8.1.0</AssemblyVersion>
<PackageReleaseNotes>Changes since v0.8:

-Removed global static graph instance.</PackageReleaseNotes>
+1. Removed global static graph instance.
+2. Provide custom gradient function.</PackageReleaseNotes>
<LangVersion>7.2</LangVersion>
<FileVersion>0.8.1.0</FileVersion>
</PropertyGroup>
