
If saved_for_backward returns NumberProxy, the value is taken from compile time, not runtime #231

Closed · kiya00 opened this issue Apr 18, 2024 · 13 comments · Fixed by #244, #264 or #481
Assignees: jjsjann123
Labels: bug, dynamic constraints, help wanted, transforms

@kiya00 (Collaborator) commented Apr 18, 2024

🐛 Bug

When I add an operator that returns numbers, the values in saved_for_backward are the compile-time values defined in Symbol.meta, not the real values computed at runtime.

An example trace looks like this:

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t2 = get_rng_state_prim_impl(None, devices.Device("cuda:0"))  # t2: "cpu ui8[16]"
  (i3, i4) = unpack_rng_state_prim_impl(t2)
  del t2
  [t1] = nvFusion0(a, i3, i4)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.mul(t0, a)  # t1: "cuda:0 f32[2, 2]"
  i6 = operator.add(i4, 4)  # i6: "int 11"
    # i6 = prims.add(i4, 4)  # i6: "int 11"
  del i4
  t7 = pack_rng_state_prim_impl(i3, i6)  # t7: "cpu ui8[16]"
  del i3, i6
  set_rng_state_prim_impl(t7, devices.Device("cuda:0"))
  del t7
  ######### i3, i4 are not passed to backward; the constant value 7 from the meta function is used instead
  return {'output': t1, 'flat_args': [a], 'flat_output': (t1,)}, ((), (7, 7))  

@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  _, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  clear_collection(cotangents)
  del cotangents
  i3, i4, = C1
  clear_collection(C1)
  del C1
  [t12] = nvFusion0(i3, i4, t2)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t12 = prims.mul(t0, t2)  # t12: "cuda:0 f32[2, 2]"
  del i3, i4, t2
  return (t12,)

To Reproduce

Reproduction is based on branch uniform_rng

import torch
import thunder

def func(a):
    b = thunder.torch.uniform_like(a, device=a.device, dtype=a.dtype)
    return b*a

a = torch.randn(2, 2, device="cuda", requires_grad=True)

jfunc = thunder.jit(func)
out = jfunc(a)

print(thunder.last_traces(jfunc)[-1])
print(thunder.last_backward_traces(jfunc)[-1])
@kiya00 added the bug, help wanted, and transforms labels Apr 18, 2024
@kshitij12345 (Collaborator) commented:

It happens due to this line. I'm not sure whether there is any impact from not baking in the compile-time value; the tests in test_grad.py seem to run fine after removing it.

saved_for_backward = tree_map(lambda x: x.value if isinstance(x, NumberProxy) else x, saved_for_backward)
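
As an illustration of what this line does, here is a minimal, self-contained sketch (toy stand-ins for illustration only, not Thunder's actual NumberProxy or pytree tree_map):

class NumberProxy:
    # Toy stand-in: carries the trace-time value recorded by the meta function (e.g. 7)
    def __init__(self, name, value):
        self.name = name
        self.value = value

def tree_map(fn, tree):
    # Minimal pytree map over nested tuples, standing in for the real tree_map
    if isinstance(tree, tuple):
        return tuple(tree_map(fn, t) for t in tree)
    return fn(tree)

# What the augmented forward saves for backward in the trace above
saved_for_backward = ((), (NumberProxy("i3", 7), NumberProxy("i4", 7)))

# The quoted line replaces each proxy with its trace-time .value, so the backward
# trace ends up with the constants (7, 7) instead of the runtime i3, i4
baked = tree_map(lambda x: x.value if isinstance(x, NumberProxy) else x, saved_for_backward)
print(baked)  # ((), (7, 7))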

I think another option could be to return a TupleProxy instead, since this function seems to return a tuple of numbers.

cc: @IvanYashchuk

@kiya00 (Collaborator, Author) commented Apr 19, 2024

Thank you @kshitij12345! Removing this line works for me.

@kiya00 (Collaborator, Author) commented Apr 19, 2024

sadly, when this line is removed, there are other cases where a NumberProxy name is returned but it doesn't exist in the trace

@IvanYashchuk (Collaborator) commented:

sadly, when this line is removed, there are other cases where a NumberProxy name is returned but it doesn't exist in the trace

That's probably the reason why these numbers were made concrete. We can't solve this particular problem easily; a lot of parts of Thunder rely on concrete numbers.

@IvanYashchuk (Collaborator) commented:

Does this cause any problems in your work?

@kiya00 (Collaborator, Author) commented Apr 22, 2024

Does this cause any problems in your work?

If we have an operator that produces NumberProxy results (unpack_rng_state in my case), and these results happen to be passed to the backward pass (like i3, i4 in the trace above), Thunder uses the NumberProxy value from the meta function (7 in the example), not the one computed at runtime (i3, i4). I hope I registered the operator the right way.

@mruberry (Collaborator) commented:

triage review:

  • we would like to pursue the general problem of being able to return a NumberProxy without a value here
  • @jjsjann123, this will probably be necessary for symbolic value support, if you're interested in taking a look

@jjsjann123 self-assigned this Apr 22, 2024
@jjsjann123 (Collaborator) commented Apr 22, 2024

Does this look about right?

root@c574d9980ec8:/volume# python thunder_issue_231.py
# Constructed by Delete Last Used (took 0 milliseconds)
import operator
import thunder.core.devices as devices
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(a):
  # a: "cuda:0 f32[2, 2]"
  t2 = get_rng_state_prim_impl(None, devices.Device("cuda:0"))  # t2: "cpu ui8[16]"
  (i3, i4) = unpack_rng_state_prim_impl(t2)
  del t2
  [t1] = nvFusion0(a, i3, i4)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t1 = prims.mul(t0, a)  # t1: "cuda:0 f32[2, 2]"
  i6 = operator.add(i4, 4)  # i6: "int 4"
    # i6 = prims.add(i4, 4)  # i6: "int 4"
  t7 = pack_rng_state_prim_impl(i3, i6)  # t7: "cpu ui8[16]"
  del i6
  set_rng_state_prim_impl(t7, devices.Device("cuda:0"))
  del t7
  return {'output': t1, 'flat_args': [a], 'flat_output': (t1,)}, ((), (i3, i4))
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  _, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t2, = cotangents
  clear_collection(cotangents)
  del cotangents
  i3, i4, = C1
  clear_collection(C1)
  del C1
  [t12] = nvFusion0(i3, i4, t2)
    # t0 = prims.uniform_philox((2, 2), 0.0, 1.0, device=devices.Device("cuda:0"), dtype=dtypes.float32, seed=i3, offset=i4)  # t0: "cuda:0 f32[2, 2]"
    # t12 = prims.mul(t0, t2)  # t12: "cuda:0 f32[2, 2]"
  del i3, i4, t2
  return (t12,)

I see (i3, i4) is being saved for backward.

This is with your branch + @kshitij12345's suggestion of removing the tree_map + #250.

I happened to be playing with this recently, and my hack seems to have plumbed it through for you.

@jjsjann123 (Collaborator) commented:

sadly, when this line is removed, there are other cases where a NumberProxy name is returned but it doesn't exist in the trace

I don't think my PR helps at all... since it apparently runs fine without it. 😆

Wondering what the other cases are that you're looking at?

@kiya00 (Collaborator, Author) commented Apr 23, 2024

Wondering what the other cases are that you're looking at?

When I remove the tree_map line and run the dropout case, it fails with NameError: name 'f7' is not defined.

def func(a):
    b = torch.nn.functional.dropout(a, p=0.5)
    return b*a

In the dropout case the NumberProxy is just a constant number (2.0), while in the uniform case it is i3, i4, which are numbers produced by an operator, so I tried to hack around it with #244. But we probably need a more general way to deal with this.
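
For completeness, here is a runnable version of the dropout case, assuming the same harness as the original reproduction above (with the tree_map line removed, this is the case that hits NameError: name 'f7' is not defined):

import torch
import thunder

def func(a):
    # p=0.5 introduces the constant scale factor 1 / (1 - p) == 2.0, which shows up
    # as a NumberProxy in the trace
    b = torch.nn.functional.dropout(a, p=0.5)
    return b * a

a = torch.randn(2, 2, device="cuda", requires_grad=True)

jfunc = thunder.jit(func)
out = jfunc(a)

print(thunder.last_traces(jfunc)[-1])
print(thunder.last_backward_traces(jfunc)[-1])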

@jjsjann123 (Collaborator) commented:

Linking issue #403

@IvanYashchuk (Collaborator) commented:

@kiya00, I think this issue was closed automatically. Is it really fixed, or should it be reopened?

@jjsjann123 (Collaborator) commented:

Ah... I forgot that we had this one open already. I'm linking it to #541, and I'll verify this when I close the other.
