
use torch.get_default_dtype and torch.get_default_device for factory method in thunder/torch/__init__.py #621

Closed
jjsjann123 opened this issue Jun 18, 2024 · 5 comments · Fixed by #820


@jjsjann123 (Collaborator) commented Jun 18, 2024

🐛 Bug

thunder produces output with a different dtype than eager in the compiled function.

To Reproduce

import thunder
import torch

def foo():
    return torch.ones((1,), device="cuda")

jfoo = thunder.jit(foo)  # works

print("thunder output: ", jfoo())  # integer type
print("ref output: ", foo())  # float type
print(thunder.last_traces(jfoo)[0])

Pitches

thunder/torch/__init__.py isn't properly pulling torch's default dtype/device in factory ops like torch.full and torch.empty, resulting in incorrect behavior.

The trace from the above function is:

thunder output:  tensor([1], device='cuda:0')
ref output:  tensor([1.], device='cuda:0')
import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation():
  # /volume/thunder_463_part2.py:5:         return torch.ones((1,), device="cuda")
  t0 = ltorch.ones((1,), device='cuda', dtype=None)  # t0: "cuda:0 i64[1]"
    # t0 = ltorch.full((1,), 1, device='cuda', dtype=None)  # t0: "cuda:0 i64[1]"
      # t0 = prims.full((1,), 1, device=devices.Device("cuda:0"), dtype=dtypes.int64_)  # t0: "cuda:0 i64[1]"
  return t0

The lowered prims.full is executed by nvFuser, which receives an explicit dtype that differs from the one the vanilla function would use.
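
For illustration, a minimal sketch of the proposed fallback (hypothetical helper name, not the actual code in thunder/torch/__init__.py): when the caller omits dtype or device, the factory op should consult torch's process-wide defaults rather than inferring int64 from the fill value.

import torch

def resolve_factory_defaults(dtype=None, device=None):
    # Fall back to torch's process-wide defaults, mirroring eager behavior.
    if dtype is None:
        dtype = torch.get_default_dtype()    # torch.float32 unless changed
    if device is None:
        device = torch.get_default_device()  # available in newer PyTorch releases
    return dtype, device

dtype, device = resolve_factory_defaults()
t = torch.full((1,), 1, dtype=dtype, device=device)  # float32, matching eager torch.ones((1,))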

cc @apaz-cli

@mruberry (Collaborator)

triage review:

  • let's capture these values at trace time and (ideally) generate a constraint that they have not changed
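
A rough sketch of that idea, using a hypothetical cache-entry shape rather than thunder's actual cache info, and assuming a PyTorch new enough to provide torch.get_default_device: snapshot the defaults when the trace is built and re-check them before reusing it.

import torch

def snapshot_defaults():
    # The defaults the trace was specialized on.
    return (torch.get_default_dtype(), str(torch.get_default_device()))

class CachedTrace:
    def __init__(self, trace_fn):
        self.trace_fn = trace_fn
        self.guard = snapshot_defaults()

    def matches(self):
        # Constraint check at the call site: retrace if the defaults changed.
        return snapshot_defaults() == self.guard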

@lantiga (Collaborator) commented Jun 24, 2024

added high-priority since it's needed to support transformers

@riccardofelluga riccardofelluga self-assigned this Jun 25, 2024
@riccardofelluga (Collaborator) commented Jul 12, 2024

I dug a bit to understand where this change might go. It looks to me that we could actually put the default device/dtype information in the cache info; however, we would still have the problem of what happens when the user changes the default device/dtype inside the compiled function, like:

import thunder
import torch
# default dtype here is torch.float32
def foo():
    a = torch.ones((1, ), device="cuda")
    torch.set_default_dtype(torch.float64)
    b = torch.ones((1, ), device="cuda")
    return a, b

jfoo = thunder.jit(foo)
a, b = jfoo()

assert a.dtype != b.dtype

This work might also require some better standardization of these tensor-creation operations, since each of them currently has slightly different behavior. For example, the snippet above should pass in torch, or at least produce a with torch.float32 dtype and b with torch.float64 dtype. However, as of writing, thunder takes the dtype information from the input tuple and creates two torch.int64 tensors. I'm unassigning this for now; I think it's worth waiting for Q3 planning before proceeding.
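
For reference, eager PyTorch applies a new default only to factory calls made after torch.set_default_dtype, so the uncompiled version of the snippet above yields mixed dtypes (CPU tensors used here so it runs without a GPU):

import torch

torch.set_default_dtype(torch.float32)
a = torch.ones((1,))                     # float32: current default
torch.set_default_dtype(torch.float64)
b = torch.ones((1,))                     # float64: picks up the new default
assert a.dtype == torch.float32 and b.dtype == torch.float64
torch.set_default_dtype(torch.float32)   # restore the usual default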

@riccardofelluga riccardofelluga removed their assignment Jul 12, 2024
@jjsjann123 (Collaborator, Author)

however then we would still have the problem of what happens when the user changes the default device/dtype inside the compiled function like:

Is the problem here that we are worried about reordering the set_default_dtype node in the trace? I feel that's a general problem in graph transformation that we don't have a good strategy for yet.

Regarding "It looks to me that we could actually put the default device/dtype information in the cache info": I think that part should work fine.

I.e., the cache info only guards things at the call site of the compiled function. The user should be able to update the default device/dtype inside the program; that's the same behavior as with eager. If the end state is the same, the follow-up iteration would be a cache hit as well.

@t-vi (Collaborator) commented Jul 16, 2024

So the dtype part of this works since @kshitij12345 fixed #750; we should likely also do the device part?
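
An illustrative check for the device half (not from the thread; assumes a CUDA build and a PyTorch version that provides torch.set_default_device):

import thunder
import torch

torch.set_default_device("cuda")

def foo():
    return torch.ones((1,))  # no explicit device: should follow the default

jfoo = thunder.jit(foo)
print(jfoo().device)  # expected cuda:0 if the default device is respected
print(foo().device)   # eager reference: cuda:0
torch.set_default_device("cpu")  # restore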
