This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[bug] fix higher grad log #15120

Merged
merged 7 commits into apache:master on Jun 20, 2019

Conversation

kshitij12345
Contributor

There is a bug in the implementation of the higher-order gradient of log.

https://github.com/apache/incubator-mxnet/blob/7b343d1fcde73b61322985580080333d9eee9e82/src/operator/tensor/elemwise_unary_op_basic.cc#L1077-L1079

We multiply gx * gx, where gx = ograd * f'(x), getting ograd^2 * f'(x)^2; however, we want only ograd * f'(x)^2, which can be achieved in a similar fashion to the implementation of _backward_log10/_backward_log2.
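
A minimal NumPy sketch (my own illustration, not the operator code) of the difference between the two quantities for f = log:

```python
import numpy as np

x = np.array([0.5, 1.0, 2.0])
ograd = np.array([0.1, 0.2, 0.3])   # incoming head gradient

f_prime = 1.0 / x                   # f'(x) for f = log
gx = ograd * f_prime                # what _backward_log outputs

buggy = gx * gx                     # ograd^2 * f'(x)^2 -- what the current code computes
correct = ograd * f_prime ** 2      # ograd * f'(x)^2   -- what we actually want

print(buggy)
print(correct)                      # the two differ whenever ograd != 1
```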

I have validated the expected results for grad on x.

from mxnet import nd, autograd
import numpy
import math

op = lambda x: nd.log(x)
grad_op = lambda x: 1/x
grad_grad_op = lambda x: (-1/x**2)

x = nd.random.normal(0,1,(3,3))
head_grads = nd.random.normal(0,1,(3,3))
head_grad_grads = nd.random.normal(0,1,(3,3))

x.attach_grad()
head_grads.attach_grad()


with autograd.record():
  y = op(x)
  y_grad = autograd.grad(y, x, head_grads=head_grads, create_graph=True, retain_graph=True)[0]
y_grad.backward(head_grad_grads)


expected_grad_grad_x = (grad_grad_op(x) * head_grad_grads * head_grads).asnumpy()
expected_head_grad = (grad_op(x) * head_grad_grads).asnumpy()

numpy.testing.assert_allclose(x.grad.asnumpy(), expected_grad_grad_x, rtol=1e-7, atol=1e-7)
numpy.testing.assert_allclose(head_grads.grad.asnumpy(), expected_head_grad, rtol=1e-7, atol=1e-7)

This fails with the current code.

I have confirmed the behaviour with PyTorch as well.

import torch
import numpy
import math

op = lambda x: torch.log(x)
grad_op = lambda x: 1/x
grad_grad_op = lambda x: (-1/x**2)

x = torch.randn(2,3)
head_grads = torch.randn(2,3)
head_grad_grads = torch.randn(2,3)
x.requires_grad = True
head_grads.requires_grad = True

y = op(x)
y_grad = torch.autograd.grad(y, x, grad_outputs= head_grads, create_graph=True, retain_graph=True)[0]
y_grad.backward(head_grad_grads)

expected_grad_grad_x = (grad_grad_op(x) * head_grad_grads * head_grads).detach().numpy()
expected_head_grad = (grad_op(x) * head_grad_grads).detach().numpy()

numpy.testing.assert_allclose(x.grad.detach().numpy(), expected_grad_grad_x, rtol=1e-7, atol=1e-7)
numpy.testing.assert_allclose(head_grads.grad.detach().numpy(), expected_head_grad, rtol=1e-7, atol=1e-7)

* bug: the head_grads were not preserved in higher order.
* add test to validate the fix of the same.
@kshitij12345
Contributor Author

@apeforest @larroy please review

auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
{nnvm::NodeEntry{ggx_mid}}, nullptr, &n);

std::vector<nnvm::NodeEntry> ret;

ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
{ograds[0], gx}, nullptr, &n));
{ograds[0], nnvm::NodeEntry{g_lx}}, nullptr, &n));
Contributor Author

https://github.com/apache/incubator-mxnet/blob/37ce3b87268a8154f5c0ad97ce2522478038ee06/tests/python/unittest/test_higher_order_grad.py#L102

I am having trouble with head_grads.grad, which is being returned as all zeros (I guess it is somehow not being updated), while I expect it to be the output of this line.
Please help.

Contributor
@larroy larroy Jun 4, 2019

Hi. What do you mean by head_grads.grad? NodeEntry doesn't have a grad field. Could you clarify? Are you referring to the Python code below? The gradient is always 0 when attach_grad() is called. The value is updated only after running backward on an output, or using autograd.grad.
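
As a quick illustration of that last point, a minimal MXNet sketch (my own, not from the PR):

```python
from mxnet import nd, autograd

x = nd.array([1.0, 2.0, 3.0])
x.attach_grad()            # x.grad is allocated and zero-initialized here

with autograd.record():
    y = x * x
y.backward()               # only now is x.grad populated with dy/dx = 2 * x

print(x.grad)              # [2. 4. 6.]
```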

Contributor

Still looking into this. The first output should be the gradient of y_grad. However, the head_grads.grad does not get the value. I suspect the returned value from this function is dropped in the gradient calculation in imperative.cc. I will look more into this. Stay tuned.

Contributor Author

Sorry for the confusion. I forgot to add the line from the test file.
Sure waiting to know what you find.

@piyushghai
Contributor

Thanks for your contributions @kshitij12345
@mxnet-label-bot Add [pr-awaiting-review, Operator]

@marcoabreu marcoabreu added the Operator and pr-awaiting-review (PR is waiting for code review) labels Jun 4, 2019
@larroy
Contributor

larroy commented Jun 4, 2019

Hi @kshitij12345 thanks for looking into this.

I think we need to clarify what exactly we have in the first parameter of FGradient, "node". @apeforest and I were a bit puzzled looking at your PR. I validated the results with the tests, but I think I tried only one log (I don't remember which base). The result seemed correct to me, so I guess I missed this problem.

Why do you say that node is ograd * f'(x)? As I understand it, the node argument is the node to calculate the gradient for; in this case we are calculating the gradient of the backward of the log. So are you saying that, by the chain rule, the node is ograd(of log) * d(log(x))/dx = ograd * reciprocal?

It would be great if we could add this to the documentation, either to the FGradient typedef or to new_op. Otherwise I always have to dig through the code to refresh this. I think it is poorly documented and tricky.


# Validate the gradients.
assert_almost_equal(expected_grad_grad, x.grad.asnumpy())
assert_almost_equal(expected_heads_grad, head_grads.grad.asnumpy())
Contributor
@larroy larroy Jun 4, 2019

Now I understand your question. I don't think anything is updating head_grads.grad here (this is done when running backward). Why do you want to set the head gradients manually? To verify your fix?

Contributor

Can you try

y_grad_grad = autograd.grad(y_grad, x, ..., create_graph = False...)[0]

and in the validation:
assert_almost_equal(expected_heads_grad, y_grad_grad.asnumpy())

Contributor Author

Yeah, to verify the fix.
I expected y_grad.backward(head_grad_grads) to update head_grads.grad, similarly to the PyTorch script in the description.

Thanks for the suggestion,
I will surely try that.

Contributor
@larroy larroy Jun 5, 2019

y_grad.backward(head_grad_grads) indicates that head_grad_grads are the head gradients passed from "upstream". Calling (output variable).backward will update all the independent input variables (on which that output depends) which have an attached gradient. In this case head_grad_grads is not an input to the graph, so your problem that the grad doesn't get updated is expected:

https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/ndarray/ndarray.py#L2188
https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/autograd.py#L270
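
A minimal sketch (my own, with hypothetical tensors) of why a tensor used only as head gradients never receives a gradient:

```python
from mxnet import nd, autograd

x = nd.array([1.0, 2.0])
w = nd.array([3.0, 4.0])   # used only as head gradients below
x.attach_grad()
w.attach_grad()

with autograd.record():
    y = x * x
y.backward(w)              # w is passed as the head gradient, it is not a recorded graph input

print(x.grad)              # dy/dx * w = 2 * x * w -> [ 6. 16.]
print(w.grad)              # stays [0. 0.]: w is not part of the recorded graph
```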

Contributor Author

We are checking gradients for head_grads (not head_grad_grads), which is used to compute x_grad, so I believe we should accumulate some gradient in head_grads.

Contributor

The default behaviour of PyTorch vs MXNet is different with respect to accumulation of gradients (PyTorch: add, MXNet: write). Having said that, I still don't understand why you expect gradient accumulation in head_grads.
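
A small sketch of the write-vs-add behaviour (my own illustration; grad_req is the MXNet knob that controls it):

```python
from mxnet import nd, autograd

x = nd.array([1.0, 2.0])
x.attach_grad()                   # default grad_req='write': each backward overwrites x.grad

for _ in range(2):
    with autograd.record():
        y = 2 * x
    y.backward()
print(x.grad)                     # [2. 2.] -- overwritten, not accumulated

x.attach_grad(grad_req='add')     # opt in to PyTorch-style accumulation
for _ in range(2):
    with autograd.record():
        y = 2 * x
    y.backward()
print(x.grad)                     # [4. 4.] -- accumulated across the two backward calls
```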

Contributor Author

Oh. I just expected it to have gradients (by accumulation or writing), as its value is used while computing x_grad. But from your and @apeforest's explanations, I understand the behaviour better now.
Thank you for digging in and explaining.

auto gx = nnvm::NodeEntry{n};
auto gx_mul_head_grads = nnvm::NodeEntry{n}; // f'(x) * head_grads
auto head_grads = nnvm::NodeEntry{n->inputs[0]};
auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
Contributor
@larroy larroy Jun 4, 2019

Can we add a comment about the inputs and what g_lx is? It would help reason about the code. Are the inputs of n (backward_log):

  • 0: input gradient
  • 1: x
    ?

So g_lx is a node holding 1/x, or the derivative of the log, right? Can we rename it to g_logx?

Contributor Author

Sure thing.

@kshitij12345
Contributor Author

> So are you saying that, by the chain rule, the node is ograd(of log) * d(log(x))/dx = ograd * reciprocal?

Following the unary_bwd

https://github.com/apache/incubator-mxnet/blob/b0be6c50ae686a59c340c56b9c11493a5c84ff72/src/operator/tensor/elemwise_unary_op.h#L400-L402

which leads here.

https://github.com/apache/incubator-mxnet/blob/2fd4720342f87be80f280176ffbecb8f37e49b73/src/operator/mxnet_op.h#L520-L538

which fuses output_grad * func_grad, where func_grad in our current case is the d(log(x))/dx operation.

Thus the node is the multiplication ograd * d(log(x))/dx, which also implies that for every second-order gradient function we have to do some operation to recover only func_grad (used for the gradient of ograd).
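
A scalar sketch of what the fused node carries and what the second-order outputs should be (my own illustration, assuming f = log):

```python
x, ograd, ograd2 = 2.0, 0.3, 0.7     # input, first-order head grad, second-order head grad

f_prime = 1.0 / x                    # d log(x) / dx
f_second = -1.0 / x ** 2             # d^2 log(x) / dx^2

gx = ograd * f_prime                 # the fused _backward_log output: ograd * f'(x)

# recover f'(x) from the fused output by dividing out ograd
grad_wrt_ograd = ograd2 * (gx / ograd)      # = ograd2 * f'(x)
grad_wrt_x = ograd2 * ograd * f_second      # = ograd2 * ograd * f''(x)
```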

Also I had observed some weird behaviour for functions wrapped with ElemGradUseOut. Will post about it later in the day.

> It would be great if we could add this to the documentation, either to the FGradient typedef or to new_op. Otherwise I always have to dig through the code to refresh this. I think it is poorly documented and tricky.

Yes, it is. There should definitely be some documentation about this.

@apeforest
Contributor

apeforest commented Jun 5, 2019

@kshitij12345 I have some question about the equation expected_head_grad = (grad_op(x) * head_grad_grads).asnumpy() in your test.

My understanding from the chain rule:

Given y =f(x)
dL/dx = dL/dy * dy/dx -->  this is the first backward pass. Let dL/dy be y_grad, we get dL/dx (noted as x_grad)

Now we rewrite the above equation:

input0: y_grad
input1: x
output: x_grad = y_grad * f'(x)

Another backward pass for this would be:
dL/d y_grad = dL/d x_grad * f'(x)
dL/dx = dL/d x_grad * y_grad * f''(x)

What is the meaning of dL/d y_grad? Are we treating y_grad as another input variable here?

Many thanks for your clarification.
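
Restating the above derivation in LaTeX for readability (same content, with L the final loss):

```latex
y = f(x), \qquad \bar{y} := \frac{\partial L}{\partial y}, \qquad
\bar{x} := \frac{\partial L}{\partial x} = \bar{y}\, f'(x)

% second backward pass, treating \bar{x} as a function of (\bar{y}, x):
\frac{\partial L}{\partial \bar{y}} = \frac{\partial L}{\partial \bar{x}}\, f'(x), \qquad
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial \bar{x}}\, \bar{y}\, f''(x)
```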

@apeforest
Contributor

apeforest commented Jun 5, 2019

I did some more probing. I think the reason head_grads.grad is all zeros is that the variable head_grads was not specified during the second backward pass. I updated the test as follows, but got an assertion error:

grad_grad_out = autograd.grad(heads=xgrad, variables=[x, head_grads], head_grads=head_grads_grads, create_graph=False, retain_graph=True)
print(grad_grad_out[0].asnumpy())

During the computation graph traversal, it complains that the variable head_grads is unreachable from the output, which I think is reasonable. This again comes back to my question above: what is the mathematical meaning of head_grads.grad and why do we need this value?

@larroy
Contributor

larroy commented Jun 5, 2019

@apeforest see my reply above; head_grads is not in the graph, so it's not updated during backward.

@apeforest
Contributor

@larroy Yes, I agree with your reply. Also, I don't understand the meaning of (or the need to) return dL/dy_grad. @kshitij12345 please comment. Thanks

y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
y_grad.backward()
assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy())
y_grad = autograd.grad(y, x, head_grads=head_grads,
Contributor
@apeforest apeforest Jun 5, 2019

This variable is actually dL/dx, maybe rename it to x_grad for better readability?

Contributor Author

Oh yes. Will do that.

x.attach_grad()

# Manual head_grads.
head_grads = nd.random.normal(shape=x.shape)
Contributor

Rename this to y_grad as it is dL/dy?

Contributor Author

Sure.


# Manual head_grads.
head_grads = nd.random.normal(shape=x.shape)
head_grad_grads = nd.random.normal(shape=x.shape)
Contributor

I still don't understand what this variable is mathematically...

Contributor Author

head_grads is just the input node in the graph for x_grad.
head_grad_grads is just to check the validity of the chain rule/backprop.

Contributor

Thanks for the clarification.

@larroy larroy mentioned this pull request Jun 5, 2019
@larroy
Contributor

larroy commented Jun 5, 2019

> @kshitij12345 I have some question about the equation expected_head_grad = (grad_op(x) * head_grad_grads).asnumpy() in your test.
>
> My understanding from the chain rule:
>
> Given y = f(x)
> dL/dx = dL/dy * dy/dx --> this is the first backward pass. Let dL/dy be y_grad, we get dL/dx (noted as x_grad)
>
> Now we rewrite the above equation:
>
> input0: y_grad
> input1: x
> output: x_grad = y_grad * f'(x)
>
> Another backward pass for this would be:
> dL/d y_grad = dL/d x_grad * f'(x)
> dL/dx = dL/d x_grad * y_grad * f''(x)
>
> What is the meaning of dL/d y_grad? Are we treating y_grad as another input variable here?
>
> Many thanks for your clarification.

As per backward accumulation of gradients and the chain rule, we always have the incoming gradient (also called the head gradient or output gradient). So the second backward pass should calculate:

(equation image omitted; it was rendered with the CodeCogs editor linked below)

I'm thinking that maybe the problem is that we should not reuse the head gradient from the first gradient in the second gradient. Shouldn't the two head gradients be independent variables? Let me know what you think.

(computation-graph image omitted; it can be regenerated with the GraphvizOnline link below)

https://dreampuf.github.io/GraphvizOnline/#digraph%20G%20%7B%0A%20%20subgraph%20cluster_1%20%7B%20%0A%20%20%20%20style%3Dfilled%3B%0A%20%20%20%20color%3Dazure1%3B%0A%20%20%20%20x%20-%3E%20log%20-%3E%20y%3B%0A%20%20%20%20label%20%3D%20%22Fwd%22%3B%0A%20%20%7D%0A%20%20subgraph%20cluster_2%20%7B%0A%20%20%20%20style%3Dfilled%3B%0A%20%20%20%20color%3Dazure2%3B%0A%20%20%20%20ograd%20-%3E%20log_bwd%3B%0A%20%20%20%20x%20-%3E%20log_bwd%20-%3E%20x_grad%3B%0A%20%20%20%20%0A%20%20%20%20label%20%3D%20%22bwd%22%3B%0A%20%20%7D%0A%20%20subgraph%20cluster_3%20%7B%0A%20%20%20%20style%3Dfilled%3B%0A%20%20%20%20color%3Dazure3%3B%0A%20%20%20%20ograd2%20-%3E%20log_bwd_bwd%3B%0A%20%20%20%20x%20-%3E%20log_bwd_bwd%20-%3E%20x_grad_grad%3B%0A%20%20%20%20%0A%20%20%20%20label%20%3D%20%22bwd%22%3B%0A%20%20%7D%0A%20%20log%20%5Bshape%3Dsquare%5D%3B%0A%20%20log_bwd%20%5Bshape%3Dsquare%5D%3B%0A%7D

https://www.codecogs.com/latex/eqneditor.php

@kshitij12345
Contributor Author

kshitij12345 commented Jun 5, 2019

> @kshitij12345 I have some question about the equation expected_head_grad = (grad_op(x) * head_grad_grads).asnumpy() in your test.
>
> My understanding from the chain rule:
>
> Given y = f(x)
> dL/dx = dL/dy * dy/dx --> this is the first backward pass. Let dL/dy be y_grad, we get dL/dx (noted as x_grad)
>
> Now we rewrite the above equation:
>
> input0: y_grad
> input1: x
> output: x_grad = y_grad * f'(x)
>
> Another backward pass for this would be:
> dL/d y_grad = dL/d x_grad * f'(x)
> dL/dx = dL/d x_grad * y_grad * f''(x)
>
> What is the meaning of dL/d y_grad? Are we treating y_grad as another input variable here?
>
> Many thanks for your clarification.

I was reasoning on the basis of this graph that I had in mind.

(image: "second_backward" computation graph omitted)

I am also not sure about its mathematical meaning. In the usual scenario it would essentially be an intermediate node, as in the picture.

However, since we are returning it, we might as well test for it. In the current test, head_grads (depicted as dL/dy) is an input leaf node in the second graph, so I expect that we will be able to accumulate gradients in it (similarly to PyTorch).

I believe the graph for the second order should be like this.

(image: "f_second_order" graph omitted)

I hope it makes sense. Thank You.

# Compute expected values.
expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
head_grads.asnumpy()
expected_heads_grad = grad_x.asnumpy()
Contributor Author

Should be grad_x.asnumpy() * head_grad_grads.asnumpy()

@apeforest
Contributor

@kshitij12345 The computation graph for the second backward pass makes sense to me. As you can see, there is only one output from the graph, that is x_grad_grad. It is not clear to me where the output dL/dy_grad is.

@kshitij12345
Contributor Author

kshitij12345 commented Jun 5, 2019

> @kshitij12345 The computation graph for the second backward pass makes sense to me. As you can see, there is only one output from the graph, that is x_grad_grad. It is not clear to me where the output dL/dy_grad is.

Actually, I haven't shown the other graph, for ograd_grad. If you take a look at x_grad, it is computed using ograd and x. So x_grad will change not only with x but also with ograd.

If you focus only on the * (multiplication) node while computing x_grad:

(image of that part of the graph omitted)

The code below is what I have in my head (it passes the assertions with the current code):

def check_second_order_unary(x, op, grad_op, grad_grad_op):
    x = nd.array(x)
    grad_x = grad_op(x)
    grad_grad_x = grad_grad_op(x)
    x.attach_grad()

    # Manual head_grads.
    y_grad = nd.random.normal(shape=x.shape)
    head_grad_grads = nd.random.normal(shape=x.shape)
    y_grad.attach_grad()

    # Perform compute.
    with autograd.record():
        y = op(x)
        x_grad_mid = autograd.grad(y, x, head_grads=nd.ones_like(x),
                               create_graph=True, retain_graph=True)[0]
        x_grad = x_grad_mid * y_grad # Note
    x_grad.backward(head_grad_grads)

    # Compute expected values.
    expected_grad = grad_x.asnumpy() * y_grad.asnumpy() # Note
    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
        y_grad.asnumpy()
    expected_heads_grad = grad_x.asnumpy() * head_grad_grads.asnumpy()

    # Validate the gradients.
    assert_almost_equal(expected_grad, x_grad.asnumpy()) # Note
    assert_almost_equal(expected_grad_grad, x.grad.asnumpy())
    assert_almost_equal(expected_heads_grad, y_grad.grad.asnumpy())

1. Manual multiplication with y_grad:

   x_grad_mid = autograd.grad(y, x, head_grads=nd.ones_like(y), create_graph=True, retain_graph=True)[0]
   x_grad = x_grad_mid * y_grad # Note this part.
   assert_almost_equal(expected_grad, x_grad.asnumpy()) # Passes

2. y_grad multiplied in the C++ definition:

   x_grad = autograd.grad(y, x, head_grads=y_grad,
                          create_graph=True, retain_graph=True)[0]
   assert_almost_equal(expected_grad, x_grad.asnumpy()) # Passes

So I expect 1 and 2 to have similar behaviour, as both of them use y_grad in the computation of the final x_grad node.

In the first case, x_grad = f'(x) * 1 * y_grad, while in the second, x_grad = f'(x) * y_grad.

My point being: x_grad depends on f'(x) (which in turn depends on x) as well as on y_grad (as asserted in the code above).
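
A quick numerical check (my own sketch, reusing the same names) that the two cases produce the same x_grad:

```python
from mxnet import nd, autograd

x = nd.array([1.0, 2.0, 3.0])
y_grad = nd.array([0.1, 0.2, 0.3])
x.attach_grad()

with autograd.record():
    y = nd.log(x)
    # case 1: multiply by y_grad manually
    x_grad_1 = autograd.grad(y, x, head_grads=nd.ones_like(y),
                             create_graph=True, retain_graph=True)[0] * y_grad
    # case 2: pass y_grad as the head gradient of the backward op
    x_grad_2 = autograd.grad(y, x, head_grads=y_grad,
                             create_graph=True, retain_graph=True)[0]

print(x_grad_1)   # y_grad / x
print(x_grad_2)   # same values
```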

@apeforest
Contributor

apeforest commented Jun 6, 2019

@kshitij12345 I think it's because of the design of the Python backward API in MXNet. When you specify variables=x, MXNet will only compute gradients for the input variables listed in variables. I did some experiments to prove my point:

As in your case 2:

x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad, create_graph=True, retain_graph=True)[0]

If you perform another backward on x_grad as x_grad.backward(out_grad=head_grad_grads), y_grad is not listed as an input variable and therefore its gradient is zero.

As in your case 1:

x_grad = x_grad_mid * y_grad # Note
x_grad.backward(out_grad=head_grad_grads)

You implicitly made y_grad an input variable when calling backward on x_grad. And that is why you will get values in y_grad.grad.

I replaced the backward() method with an explicit autograd.grad() call, which should call the same C++ backend function, and the result is different.

Case 1.1: if I do the following, I again don't get any values for y_grad, because the output only contains one gradient variable:

out_grad = autograd.grad(heads=x_grad, variables=x, head_grads=head_grad_grads, create_graph=False, retain_graph=False)
print(out_grad[0])   # value equals expected_grad_grad

Case 1.2: if I explicitly set y_grad as an input variable, I then get the expected result as in your case 1:

out_grad = autograd.grad(heads=x_grad, variables=[x, y_grad], head_grads=head_grad_grads, create_graph=False, retain_graph=False)
print(out_grad[0])   # value equals expected_grad_grad
print(out_grad[1])   # value equals expected_heads_grad

At this point, I am not sure this is a bug, because the backward API is designed differently from PyTorch's. If y_grad is not specified as one of the input variables to compute gradients for, it will not get a value assigned even if you call y_grad.attach_grad() on it. This seems consistent with the API spec. Also, given that the value of y_grad.grad does not have a real useful meaning, I don't feel the necessity to store it. Please let me know if this makes sense.

Thanks a lot for your careful drawing and insightful discussion.
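
A minimal sketch of the point about `variables` (my own, with hypothetical tensors, not the PR's test):

```python
from mxnet import nd, autograd

x = nd.array([1.0, 2.0])
w = nd.array([3.0, 4.0])
x.attach_grad()
w.attach_grad()

with autograd.record():
    z = x * w                       # both x and w are recorded graph inputs

# gradients are returned only for the tensors listed in `variables`
dx, = autograd.grad(heads=z, variables=[x], head_grads=nd.ones_like(z),
                    retain_graph=True)
dx_dw = autograd.grad(heads=z, variables=[x, w], head_grads=nd.ones_like(z))

print(dx)         # w
print(dx_dw[1])   # x -- only computed because w was listed in `variables`
```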

@apeforest
Contributor

As a follow-up, I just dumped out the computation graph in case 2. Indeed, the node that is used to calculate y_grad.grad is not even in the final symbolic graph, because there is no input dependency.

@kshitij12345
Contributor Author

Oh, thank you very much for explaining what is happening. I get it now.
It makes sense, as passing an argument as head_grads doesn't add any explicit computation to the graph.

y_grad may not have any useful meaning as of now; however, I believe we should test for the value using case 1.2, just to verify and catch any mistakes there, rather than have the same thing bite us in the future. Your thoughts? @larroy @apeforest

Also, it would be really great if you could tell me how to get the dump of the graph; that would be really helpful.

Thank you again for your time and efforts.

@larroy
Contributor

larroy commented Jun 7, 2019

@apeforest can you paste your dump of the graph, and share where you dumped it from? I was working on a utility to dump the graph.

@apeforest
Contributor

> Also, it would be really great if you could tell me how to get the dump of the graph; that would be really helpful.

@kshitij12345 There is a utility function inline void LogMemoryPlan(const nnvm::Graph& g) in exec_utils.h that can dump out the computation graph. However, it currently only works in symbolic mode, and you need to manually copy-paste it into imperative.cc to dump out the graph in the NDArray APIs. Besides, you also need to comment out the const auto& vstorage = g.GetAttr<nnvm::StorageVector>("storage_id") line, otherwise you will get a core dump.

In short, it's not an easy-to-use debugging utility; @larroy is working to improve it and will create a PR soon. Stay tuned :)

@apeforest
Contributor

> y_grad may not have any useful meaning as of now; however, I believe we should test for the value using case 1.2, just to verify and catch any mistakes there, rather than have the same thing bite us in the future. Your thoughts? @larroy @apeforest

I don't have a clear answer on how to test this value now. Definitely not through a Python unit test, because the value does not get output as an NDArray. Besides, as I commented earlier, even if that computation node was in the nnvm graph, it may have been deleted later by graph optimization because it is a dangling node with no output.

* remove assertion for y_grad gradient.
* rename variables.
* fix and update computation.
@kshitij12345
Contributor Author

kshitij12345 commented Jun 8, 2019

@apeforest, thank you for explaining how to get the dump of the graph. Waiting for the PR which simplifies that.

Also, could you tell me if it is possible to print a Node and its value (the underlying array) via C++, and if yes, how?

@apeforest
Contributor

> @apeforest, thank you for explaining how to get the dump of the graph. Waiting for the PR which simplifies that.
>
> Also, could you tell me if it is possible to print a Node and its value (the underlying array) via C++, and if yes, how?

We don't yet have a utility function to do that. What do you mean by the value of the node? A node only represents a computation in the graph. The values are passed in/out when the graph is traversed and each node invokes its registered FCompute/FGradient functions. I think we can only track values in the corresponding NDArray outputs. Please let me know if I misunderstand your question.

@kshitij12345
Contributor Author

kshitij12345 commented Jun 10, 2019

Oops, sorry for the confusion. I meant the value of an NDArray corresponding to a computation.
For example, given auto gx = nnvm::NodeEntry{n};, can I get/print the value of the NDArray corresponding to that node?

Thank You.

@larroy
Contributor

larroy commented Jun 10, 2019

Right now it is very difficult to print the value of an NDArray. I will work on a utility to dump the graph this week: first the graph itself, and if possible the NDArray values.

@apeforest
Contributor

There is no utility available to print out values from an NDArray. I used to write a for loop to iterate over the dptr pointer and dump out values. Besides, we need extra synchronization handling if the NDArray is stored on GPU. I agree with @larroy that this is something very important for debugging the MXNet backend. I have created issue #15198 to ask for contributions.

y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
y_grad.backward()
assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy())
x_grad = autograd.grad(y, x, head_grads=y_grad,
Contributor

Can you explicitly specify the arguments as:

Suggested change
x_grad = autograd.grad(y, x, head_grads=y_grad,
x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad, create_graph=True, retain_graph=True)

I think this will make it easier to understand.

@apeforest
Contributor

@kshitij12345 Could you please rebase and retrigger CI again? Thanks!

* explicitly pass arguments with name.
Contributor
@apeforest apeforest left a comment

LGTM. Let's wait for the CI to pass.

@apeforest
Contributor

@kshitij12345 still one GPU test failed. I looked at the log and it doesn't seem related to your change. Could you please rebase and trigger CI one more time? The master branch was broken last weekend. Thanks again.

@kshitij12345
Contributor Author

@apeforest Note that the pending job has succeeded but for some reason it isn't updated here.

@apeforest
Contributor

@kshitij12345 There was some issue with CI recently. Could you please re-trigger it one more time? Sorry for the inconvenience.

Co-Authored-By: Lin Yuan <apeforest@gmail.com>
@apeforest apeforest merged commit 2b7fbc5 into apache:master Jun 20, 2019
Contributor
@larroy larroy left a comment

This is much clearer, thanks. Nice variable naming.

auto gx = nnvm::NodeEntry{n, 0, 0};
auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
auto dydx_mul_dldy = nnvm::NodeEntry{n}; // f'(x) * head_grads
auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx",
Contributor

Doesn't elemwise_div require two inputs?

Contributor Author

Oh, it does. However, I guess this node is skipped in the test's computation graph and hence we don't see the error.

I will fix it. However, we should somehow find a way to test for this.
