
Disable Upcasting and Downcasting to BF16/FP16 around matmul and linear operations for the nvFuser Executor #2054

Closed
kevinstephano opened this issue Apr 8, 2024 · 4 comments

Comments

@kevinstephano
Collaborator

In Thunder, the behavior is to explicitly upcast to FP32 and then downcast back to FP16/BF16 around a set of fusion operations. For matmul and linear operations, this casting would inadvertently signal that TensorCores should not be used, so this behavior needs to be changed.
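
For illustration (this sketch is not from the original report; the shapes and names are made up), the cast pattern in question looks roughly like this:

import torch

a = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)
w = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)

# Desired: keep the operands in reduced precision so the matmul can be
# dispatched to a BF16 TensorCore kernel.
out = torch.matmul(a, w)

# Undesired pattern described above: upcast the operands to FP32 before the
# matmul and downcast the result afterwards, which steers the backend toward
# an FP32 GEMM instead of a BF16 TensorCore GEMM.
out_cast = torch.matmul(a.float(), w.float()).to(torch.bfloat16)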

Please consult with @jjsjann123 about the appropriate course of action!

@jjsjann123
Collaborator

There are many pieces that could affect how a matmul looks in the trace, e.g. grad transform rules and autocast.

(Note: I don't know how the autocast transform is used in Thunder, but looking at the code, I think it downcasts inputs to reduced precision, which is exactly what we want. cc'ing @IvanYashchuk / @tfogal, who might know better.)

Overall, Thunder doesn't apply any type-promotion logic to torch matmul/linear at the decomposition level: in lightning-thunder/thunder/torch/__init__.py, matmul/linear is mapped directly to the prim.
This contrasts with how binary operations are handled in Thunder: in lightning-thunder/thunder/clang/__init__.py, add -> clang.add -> _elementwise_binary_wrapper performs type promotion.
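
As a simplified sketch of that contrast (plain PyTorch standing in for the decompositions; this is not Thunder's actual source):

import torch

def add_decomposition(a, b):
  # elementwise binary ops: upcast to fp32, compute, then downcast back to
  # the input dtype (this is what the trace below shows for ltorch.add)
  a32 = a.to(torch.float32)
  b32 = b.to(torch.float32)
  return (a32 + b32).to(a.dtype)

def matmul_decomposition(a, b):
  # matmul/linear: mapped straight to the prim with no type promotion, so
  # the operands stay in bf16/fp16
  return torch.matmul(a, b)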

So generally I don't think we have any issue with Thunder for now; Thunder shows the inputs to matmul/linear in the proper dtype, as nvFuser would want to see them. The example below shows how the grad transform still leaves the matmul inputs in reduced precision:

import torch
import thunder

def foo(a, b, w):
  return torch.matmul(a + b, w)
  # similarly, we could produce a trace with linear instead:
  # return torch.nn.functional.linear(a + b, w, b)

dtype = torch.bfloat16

x = torch.randn(8, 16, device="cuda").to(dtype=dtype)
y = torch.randn(16, device="cuda").to(dtype=dtype)
z = torch.randn(16, 16, device="cuda").to(dtype=dtype)
x.requires_grad_()
y.requires_grad_()
z.requires_grad_()

x_ref = x.detach()
y_ref = y.detach()
z_ref = z.detach()
x_ref.requires_grad_()
y_ref.requires_grad_()
z_ref.requires_grad_()
out_ref = foo(x_ref, y_ref, z_ref)

jfoo = thunder.jit(foo)
out = jfoo(x, y, z)

assert out.allclose(out_ref)

out.sum().backward()
out_ref.sum().backward()

assert x.allclose(x_ref)
assert y.allclose(y_ref)
assert z.allclose(z_ref)

print("fwd\n", jfoo._lc_cs.last_traces[0])
print("bwd\n", jfoo._lc_cs.last_backward_traces[0])

Which gives trace of:

fwd
 import thunder
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def computation(a, b, w):
  # a: "cuda:0 bf16[8, 16]"
  # b: "cuda:0 bf16[16]"
  # w: "cuda:0 bf16[16, 16]"
  result = ltorch.add(a, b, alpha=None)  # result: "cuda:0 bf16[8, 16]"
    # t0 = prims.broadcast_in_dim(b, (8, 16), (1,))  # t0: "cuda:0 bf16[8, 16]"
    # t1 = prims.convert_element_type(a, dtypes.float32)  # t1: "cuda:0 f32[8, 16]"
    # t2 = prims.convert_element_type(t0, dtypes.float32)  # t2: "cuda:0 f32[8, 16]"
    # t3 = prims.add(t1, t2)  # t3: "cuda:0 f32[8, 16]"
    # result = prims.convert_element_type(t3, dtypes.bfloat16)  # result: "cuda:0 bf16[8, 16]"
  t5 = ltorch.matmul(result, w)  # t5: "cuda:0 bf16[8, 16]"
    # t5 = prims.matmul(result, w)  # t5: "cuda:0 bf16[8, 16]"
  return t5
bwd
 # Constructed by Backward pass
import thunder
import thunder.core.dtypes as dtypes
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  t6, = cotangents
  t4, w, = C0
  # C1 (empty sequence)
  t13 = prims.transpose(w, (1, 0))  # t13: "cuda:0 bf16[16, 16]"
  t14 = ltorch.matmul(t6, t13)  # t14: "cuda:0 bf16[8, 16]"
    # t14 = prims.matmul(t6, t13)  # t14: "cuda:0 bf16[8, 16]"
  t15 = prims.transpose(t4, (1, 0))  # t15: "cuda:0 bf16[16, 8]"
  t16 = ltorch.matmul(t15, t6)  # t16: "cuda:0 bf16[16, 16]"
    # t16 = prims.matmul(t15, t6)  # t16: "cuda:0 bf16[16, 16]"
  t17 = prims.convert_element_type(t14, dtypes.float32)  # t17: "cuda:0 f32[8, 16]"
  t18 = prims.convert_element_type(t17, dtypes.bfloat16)  # t18: "cuda:0 bf16[8, 16]"
  t19 = prims.convert_element_type(t17, dtypes.bfloat16)  # t19: "cuda:0 bf16[8, 16]"
  t22 = ltorch.sum(t18, (0,), False, dtype=None)  # t22: "cuda:0 bf16[16]"
    # t20 = ltorch.to(t18, dtypes.float32, None, device=None, dtype=None, copy=False)  # t20: "cuda:0 f32[8, 16]"
      # t20 = prims.convert_element_type(t18, dtypes.float32)  # t20: "cuda:0 f32[8, 16]"
    # t21 = prims.sum(t20, (0,))  # t21: "cuda:0 f32[16]"
    # t22 = ltorch.to(t21, dtypes.bfloat16, None, device=None, dtype=None, copy=False)  # t22: "cuda:0 bf16[16]"
      # t22 = prims.convert_element_type(t21, dtypes.bfloat16)  # t22: "cuda:0 bf16[16]"
  return (t19, t22, t16)

@IvanYashchuk
Collaborator

IvanYashchuk commented Apr 10, 2024

This is a great and accurate summary, Jie! Thunder's autocast could be renamed to "auto-downcast", as that is what it does today, and only for linear, matmul, and sdpa. It should be applied automatically behind the scenes when the jitted function is called under the torch.autocast context manager.
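
A hypothetical usage sketch of that path (the function and tensor names here are made up; the trace inspection mirrors the example above):

import torch
import thunder

def f(x, w):
  return torch.nn.functional.linear(x, w)

jf = thunder.jit(f)

x = torch.randn(8, 16, device="cuda")   # fp32 inputs
w = torch.randn(16, 16, device="cuda")

# Calling the jitted function under torch.autocast should let Thunder's
# "auto-downcast" cast the linear's inputs to bf16 behind the scenes.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
  out = jf(x, w)

print(jf._lc_cs.last_traces[0])  # check whether the inputs were downcast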

In Thunder, the behavior is to explicitly upcast to FP32 and then downcast back to FP16/BF16 around a set of fusion operations.

It is true that Thunder upcasts to FP32 in general, except for matmul, linear, conv, sdpa, and maybe some other special operations.

@kevinstephano
Collaborator Author

Waiting on Priya to verify this works once matmul and linear are enabled in Thunder.

@Priya2698
Collaborator

PRs Lightning-AI/lightning-thunder#318 and Lightning-AI/lightning-thunder#207 enable matmul and linear for the nvFuser executor in Thunder.
