Description
Tracking issue for the user-experience problem raised in the issue comment here.
The problem arises when translating a user program like the one below:
# t18 = prims.where(t17, -0.0, 0.0) # t18: "cuda:0 f32[4, 2, 3]"
We are calling nvFuser's `where` on `(boolean_tv, scalar_1, scalar_2)`.
The type promotion/inference logic determines the output dtype of `where` from input[1] and input[2].
This operation produces an output TV, but because nvFuser represents all scalars as double, it generates a double tensor as output instead of the float32 the Thunder script expects (Thunder sees the input scalars as float).
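A minimal sketch of how this can be reproduced with the nvFuser Python frontend, assuming the `FusionDefinition` API (`define_tensor`, `define_scalar`, `ops.where`); the shape and exact argument names are illustrative and may differ slightly across versions:

```python
import torch
from nvfuser import FusionDefinition, DataType

with FusionDefinition() as fd:
    # Boolean predicate tensor (corresponds to t17 in the Thunder trace).
    t17 = fd.define_tensor(shape=[-1, -1, -1],
                           contiguity=[True, True, True],
                           dtype=DataType.Bool)
    # Python scalars are held as double on the nvFuser side.
    s1 = fd.define_scalar(-0.0)
    s2 = fd.define_scalar(0.0)
    # Output dtype is promoted from the two scalar operands -> double.
    t18 = fd.ops.where(t17, s1, s2)
    fd.add_output(t18)

mask = torch.rand(4, 2, 3, device="cuda") > 0.5
out = fd.execute([mask])[0]
print(out.dtype)  # observed: torch.float64; the Thunder trace expects torch.float32
```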
This isn't blocking, as we can always patch the Thunder executor with explicit type inference: Lightning-AI/lightning-thunder#1734
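For reference, a rough sketch of what such an executor-side workaround could look like: explicitly cast the `where` output to the dtype Thunder inferred for the output proxy instead of relying on nvFuser's scalar promotion. The helper name and the way `expected_dtype` is obtained are hypothetical; only `fd.ops.where` and `fd.ops.cast` are assumed frontend ops.

```python
from nvfuser import DataType

def where_with_explicit_dtype(fd, pred, a, b, expected_dtype=DataType.Float):
    """Hypothetical Thunder-side helper: force the output dtype rather than
    relying on nvFuser's scalar type promotion."""
    out = fd.ops.where(pred, a, b)
    # Cast back to the dtype recorded for the output proxy in the trace
    # (e.g. DataType.Float for an f32 tensor).
    return fd.ops.cast(out, dtype=expected_dtype)
```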
Per the Thunder developers' request, we still want an issue to track this so that nvFuser users no longer need these workarounds (WARs).