Skip to content

Commit

Permalink
[DoubleGrad PR #3] Supported higher-order GradNode generation (Paddle…
Browse files Browse the repository at this point in the history
…Paddle#41051)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Supported higher-order grad node generation

* [DoubleGrad PR #4] Supported higher-order GradNode generation

* Fixed yaml typo
  • Loading branch information
jim19930609 committed Mar 30, 2022
1 parent 489a64e commit abd2df4
Show file tree
Hide file tree
Showing 3 changed files with 263 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def FindForwardName(string):
return string[:-5]


def IsGradName(string):
return string.endswith("_grad")


def IsPlainTensorType(string):
plain_tensor_types = ['Tensor&', 'Tensor', 'const Tensor&', 'const Tensor']
if string in plain_tensor_types:
Expand Down Expand Up @@ -166,6 +170,12 @@ def GetForwardFunctionName(string):
return f"{string}_final_state_dygraph_function"


def TransformGradVarNameForDoubleGradGeneration(string):
if IsGradName(string):
string = "grad_" + string[:-5]
return string


######################
### Yaml Parsers ###
######################
Expand Down
Loading

0 comments on commit abd2df4

Please sign in to comment.