-
Notifications
You must be signed in to change notification settings - Fork 503
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
extend gradient function capability.
- Loading branch information
1 parent
4ff993b
commit f13e35d
Showing
8 changed files
with
107 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,15 @@ | ||
# Chapter. Gradient | ||
|
||
### Register custom gradient function | ||
|
||
TF.NET is extensible: you can register a custom gradient function for an operation type. | ||
|
||
```csharp | ||
// define gradient function | ||
ops.RegisterGradientFunction("ConcatV2", (oper, out_grads) => | ||
{ | ||
    var grad = out_grads[0]; | ||
return new Tensor[]{ }; | ||
}); | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Text; | ||
|
||
namespace Tensorflow.Gradients | ||
{ | ||
public class RegisterGradient : Attribute | ||
{ | ||
public string Name { get; set; } | ||
|
||
public RegisterGradient(string name) | ||
{ | ||
Name = name; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
105 changes: 45 additions & 60 deletions
105
src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,80 +1,65 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
using System.Reflection; | ||
using System.Text; | ||
using Tensorflow.Gradients; | ||
|
||
namespace Tensorflow | ||
{ | ||
public partial class ops | ||
{ | ||
static Dictionary<string, Func<Operation, Tensor[], Tensor[]>> gradientFunctions = null; | ||
|
||
/// <summary> | ||
/// Regiter new gradient function | ||
/// </summary> | ||
/// <param name="name">operation type</param> | ||
/// <param name="func">function delegate</param> | ||
public static void RegisterGradientFunction(string name, Func<Operation, Tensor[], Tensor[]> func) | ||
{ | ||
if(gradientFunctions == null) | ||
gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>(); | ||
|
||
gradientFunctions[name] = func; | ||
} | ||
|
||
public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op) | ||
{ | ||
if (op.inputs == null) return null; | ||
|
||
// map tensorflow\python\ops\math_grad.py | ||
return (oper, out_grads) => | ||
if (gradientFunctions == null) | ||
{ | ||
// Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'"); | ||
gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>(); | ||
|
||
switch (oper.type) | ||
var gradGroups = Assembly.GetExecutingAssembly() | ||
.GetTypes() | ||
.Where(x => x.GetCustomAttribute<RegisterGradient>() != null) | ||
.ToArray(); | ||
|
||
foreach (var g in gradGroups) | ||
{ | ||
case "Add": | ||
return math_grad._AddGrad(oper, out_grads); | ||
case "BiasAdd": | ||
return nn_grad._BiasAddGrad(oper, out_grads); | ||
case "ConcatV2": | ||
return array_grad._ConcatGradV2(oper, out_grads); | ||
case "DivNoNan": | ||
return math_grad._DivNoNanGrad(oper, out_grads); | ||
case "Exp": | ||
return math_grad._ExpGrad(oper, out_grads); | ||
case "Identity": | ||
return math_grad._IdGrad(oper, out_grads); | ||
case "Log": | ||
return math_grad._LogGrad(oper, out_grads); | ||
case "MatMul": | ||
return math_grad._MatMulGrad(oper, out_grads); | ||
case "Merge": | ||
return control_flow_grad._MergeGrad(oper, out_grads); | ||
case "Mul": | ||
return math_grad._MulGrad(oper, out_grads); | ||
case "Mean": | ||
return math_grad._MeanGrad(oper, out_grads); | ||
case "Neg": | ||
return math_grad._NegGrad(oper, out_grads); | ||
case "Sum": | ||
return math_grad._SumGrad(oper, out_grads); | ||
case "Sub": | ||
return math_grad._SubGrad(oper, out_grads); | ||
case "Pow": | ||
return math_grad._PowGrad(oper, out_grads); | ||
case "RealDiv": | ||
return math_grad._RealDivGrad(oper, out_grads); | ||
case "Reshape": | ||
return array_grad._ReshapeGrad(oper, out_grads); | ||
case "Relu": | ||
return nn_grad._ReluGrad(oper, out_grads); | ||
case "Sigmoid": | ||
return math_grad._SigmoidGrad(oper, out_grads); | ||
case "Square": | ||
return math_grad._SquareGrad(oper, out_grads); | ||
case "Squeeze": | ||
return array_grad._SqueezeGrad(oper, out_grads); | ||
case "Softmax": | ||
return nn_grad._SoftmaxGrad(oper, out_grads); | ||
case "SoftmaxCrossEntropyWithLogits": | ||
return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads); | ||
case "SparseSoftmaxCrossEntropyWithLogits": | ||
return nn_grad._SparseSoftmaxCrossEntropyWithLogitsGrad(oper, out_grads); | ||
case "Transpose": | ||
return array_grad._TransposeGrad(oper, out_grads); | ||
case "TopK": | ||
case "TopKV2": | ||
return nn_grad._TopKGrad(oper, out_grads); | ||
default: | ||
throw new NotImplementedException($"get_gradient_function {oper.type}"); | ||
var methods = g.GetMethods().Where(x => x.GetCustomAttribute<RegisterGradient>() != null) | ||
.ToArray(); | ||
|
||
foreach (var m in methods) | ||
{ | ||
RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name, | ||
(oper, out_grads) => | ||
g.InvokeMember(m.Name, | ||
BindingFlags.InvokeMethod, | ||
null, | ||
null, | ||
args: new object[] { oper, out_grads }) as Tensor[] | ||
); | ||
} | ||
} | ||
}; | ||
} | ||
|
||
if (!gradientFunctions.ContainsKey(op.type)) | ||
throw new NotImplementedException($"can't get graident function through get_gradient_function {op.type}"); | ||
|
||
return gradientFunctions[op.type]; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters