Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,8 @@ def test_basic_linear_quantized(
@pytest.mark.parametrize("bias", (False, True))
@pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8"))
@pytest.mark.parametrize("quantized_weight", (False, True))
@pytest.mark.parametrize("input_requires_grad", (False, True))
@pytest.mark.parametrize("weight_requires_grad", (False, True))
def test_linear(
self,
*,
Expand All @@ -923,6 +925,8 @@ def test_linear(
device: torch.device = "cuda",
quantization: Optional[str],
quantized_weight: bool,
input_requires_grad: bool,
weight_requires_grad: bool,
) -> None:
"""GEMM + bias"""

Expand All @@ -943,9 +947,10 @@ def test_linear(
test_device=device,
test_is_fp8=quantized_compute,
)
if isinstance(x_test, QuantizedTensor):
with torch.no_grad():
x_test = x_test.dequantize().requires_grad_()
with torch.no_grad():
if isinstance(x_test, QuantizedTensor):
x_test = x_test.dequantize()
x_test.requires_grad_(requires_grad=input_requires_grad)
w_ref, w_test = make_reference_and_test_tensors(
(out_features, in_features),
test_dtype=dtype,
Expand Down Expand Up @@ -986,9 +991,12 @@ def test_linear(
op.bias.copy_(b_test)
del w_test
del b_test
for param in op.parameters():
param.requires_grad_(requires_grad=weight_requires_grad)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = op(x_test)
y_test.backward(dy_test)
if input_requires_grad or weight_requires_grad:
y_test.backward(dy_test)

# Expected numerical error
tols = dtype_tols(dtype)
Expand All @@ -999,14 +1007,16 @@ def test_linear(

# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
if bias:
db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols)
if input_requires_grad:
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
if weight_requires_grad:
dw_test = op.weight.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(dw_test, w_ref.grad, **tols)
if bias:
db_test = op.bias.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(db_test, b_ref.grad, **tols)

@pytest.mark.parametrize("weight_shape", ((7, 2), (32,)))
@pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1)))
Expand Down
69 changes: 23 additions & 46 deletions transformer_engine/jax/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,59 +142,30 @@ def _calculate_remaining_shape(shape, contracting_dims):
return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims)


# Apply jit to guarantee correctness of FP8 GEMM.
@partial(
jax.jit,
static_argnums=(
2,
3,
4,
),
)
def __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Reshape + Transpose
# [..., M, K] -> [B, M, K]
# [..., K, M] -> [B, M, K]
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")

dim_nums = (((2,), (2,)), ((0,), (0,)))
out_fp8 = jax.lax.dot_general(
lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=jnp.float32
)
scale_inv = (lhs.scale_inv * rhs.scale_inv).astype(jnp.float32)
def _transpose_contract_dims(ndim, contracting_dims):
return tuple(ndim - i - 1 for i in contracting_dims)[::-1]

return (out_fp8 * scale_inv).astype(lhs.dq_dtype)


def _jax_gemm_tensor_scaling_fp8(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
"""FP8 GEMM"""
assert rhs.scaling_mode.is_tensor_scaling(), "rhs does not have tensor scaling mode"

# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(2, 3))
def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision):
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.data_layout == "T":
lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
lhs_contract = _transpose_contract_dims(lhs.data.ndim, lhs_contract)
if rhs.data_layout == "T":
rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)

lhs_dn = (lhs_contract, lhs_batch)
rhs_dn = (rhs_contract, rhs_batch)
rhs_contract = _transpose_contract_dims(rhs.data.ndim, rhs_contract)

lhs_remain_shape = _calculate_remaining_shape(lhs.data.shape, lhs_contract)
rhs_remain_shape = _calculate_remaining_shape(rhs.data.shape, rhs_contract)
dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)

precision = (
jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT
out_fp8 = jax.lax.dot_general(
lhs.data, rhs.data, dim_nums, precision=precision, preferred_element_type=jnp.float32
)
out_3d = __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
scale_inv = (lhs.scale_inv * rhs.scale_inv).astype(jnp.float32)

# Reshape [B, M, N] -> [..., M, N]
out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
return out
return (out_fp8 * scale_inv).astype(lhs.dq_dtype)


@partial(jax.jit, static_argnums=(2,))
def _jax_gemm_mxfp8_1d(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
Expand All @@ -204,7 +175,6 @@ def _jax_gemm_mxfp8_1d(
assert (
rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING
), "rhs does not have MXFP8 1D scaling mode"
from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper

(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums

Expand Down Expand Up @@ -235,7 +205,7 @@ def _jax_gemm_mxfp8_1d(
# * Expected shape:
# * lhs_data (B, M, K) * rhs_data (B, N, K)
# * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block)
out_3d = scaled_matmul_wrapper(
out_3d = jax.nn.scaled_matmul(
lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype
)
# Reshape [1, reduce(..., M), N] -> [..., M, N]
Expand All @@ -262,9 +232,16 @@ def _jax_gemm(
dim_nums = (contracting_dims, ((), ()))

def _jax_gemm_fp8_impl(lhs, rhs):

if lhs.scaling_mode.is_tensor_scaling():
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums)
assert (
rhs.scaling_mode == lhs.scaling_mode
), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
precision = (
jax.lax.Precision.HIGHEST
if QuantizeConfig.FP8_2X_ACC_FPROP
else jax.lax.Precision.DEFAULT
)
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)

if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
Expand Down
66 changes: 43 additions & 23 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,9 @@ def _functional_forward(
input_quantizer: Optional[Quantizer] = None,
weight_quantizer: Optional[Quantizer] = None,
output_quantizer: Optional[Quantizer] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
input_requires_grad: bool = True,
weight_requires_grad: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Functional API for forward pass

Parameters
Expand Down Expand Up @@ -385,17 +387,25 @@ def _functional_forward(
Builder class for quantized weight tensor.
output_quantizer: Quantizer, optional
Builder class for quantized output tensor.
input_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the input tensor is
required in the backward pass.
weight_requires_grad: bool, default = `True`
Whether the loss gradient w.r.t. the weight tensor is
required in the backward pass.

Returns
-------
torch.Tensor
Output tensor
torch.Tensor
Input tensor used in GEMM, possibly cast and reshaped from
provided input tensor
torch.Tensor
Weight tensor used in GEMM, possibly cast and reshaped from
provided weight tensor
torch.Tensor, optional
Input tensor, ready for use in backward pass. `None` is
returned if loss gradient w.r.t. the weight tensor is not
required.
torch.Tensor, optional
Weight tensor, ready for use in backward pass. `None` is
returned if loss gradient w.r.t. the input tensor is not
required.

"""

Expand All @@ -416,7 +426,7 @@ def _functional_forward(
if with_quantized_compute:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=True)
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
if with_x_all_gather:
input_quantizer.set_usage(columnwise=False)
x, x_async = gather_along_first_dim(
Expand Down Expand Up @@ -449,7 +459,7 @@ def _functional_forward(
if with_quantized_compute and not w_is_quantized:
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
weight_quantizer.set_usage(rowwise=True)
weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
w = w.dequantize()
Expand Down Expand Up @@ -526,17 +536,25 @@ def _functional_forward(
else:
torch.distributed.all_reduce(y, group=tensor_parallel_group)

# Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save
# input tensor as context for backward pass.
if x_local is input:
x_local = x_local.detach()
# Prepare weight tensor for backward pass
if input_requires_grad:
if w is not weight and with_quantized_compute and isinstance(w, QuantizedTensor):
w.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
w = None

# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
# Prepare input tensor for backward pass
if weight_requires_grad:
if x_local is input:
# PyTorch autograd produces esoteric errors if we
# cache input tensor directly.
x_local = x_local.detach()
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
else:
x_local = None

return y, x_local, w

Expand Down Expand Up @@ -892,7 +910,7 @@ def op_forward(
dtype = torch.get_autocast_dtype("cuda")

# Linear forward
output, x_local, _ = BasicLinear._functional_forward(
output, x_local, w = BasicLinear._functional_forward(
input=input_,
weight=self.weight,
dtype=dtype,
Expand All @@ -903,10 +921,12 @@ def op_forward(
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=output_quantizer,
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
)

# Save state for backward pass
ctx.save_for_backward(x_local)
ctx.save_for_backward(x_local, w)
ctx.with_quantized_compute = with_quantized_compute
ctx.input_quantizer = input_quantizer
ctx.weight_quantizer = weight_quantizer
Expand All @@ -926,7 +946,7 @@ def op_backward(
) -> tuple[torch.Tensor, Iterable[Optional[torch.Tensor]]]:

# Saved tensors from forward pass
(x_local,) = ctx.saved_tensors
(x_local, w) = ctx.saved_tensors

# wgrad fusion
accumulate_into_main_grad = self._accumulate_into_main_grad
Expand All @@ -946,7 +966,7 @@ def op_backward(
grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=self.weight,
weight=w,
input_requires_grad=ctx.input_requires_grad,
weight_requires_grad=ctx.weight_requires_grad,
dtype=ctx.dtype,
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/ops/fused/backward_linear_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def fuser_backward(
linear_op_ctx = basic_op_ctxs[0]

# Saved tensors from forward pass
(x_local,) = linear_op_ctx.saved_tensors
(x_local, w) = linear_op_ctx.saved_tensors

# wgrad fusion
accumulate_into_main_grad = linear_op._accumulate_into_main_grad
Expand All @@ -72,7 +72,7 @@ def fuser_backward(
grad_input, grad_weight = BasicLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=linear_op.weight,
weight=w,
input_requires_grad=linear_op_ctx.input_requires_grad,
weight_requires_grad=linear_op_ctx.weight_requires_grad,
dtype=grad_input.dtype,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def fuser_forward(
else:
raise NotImplementedError("Activations are not yet supported")

# Check which grads are required
input_requires_grad = linear_op_ctx.requires_grad and input_.requires_grad
weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

# FP8 metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
Expand All @@ -106,7 +110,7 @@ def fuser_forward(
dtype = torch.get_autocast_dtype("cuda")

# Linear forward
output, x_local, _ = BasicLinear._functional_forward(
output, x_local, w = BasicLinear._functional_forward(
input=input_,
weight=linear_op.weight,
bias=bias,
Expand All @@ -118,18 +122,20 @@ def fuser_forward(
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=output_quantizer,
input_requires_grad=input_requires_grad,
weight_requires_grad=weight_requires_grad,
)

# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.save_for_backward(x_local, w)
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_requires_grad = input_.requires_grad
linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad
linear_op_ctx.input_requires_grad = input_requires_grad
linear_op_ctx.weight_requires_grad = weight_requires_grad
linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None

return output, [() for _ in range(len(self.basic_ops))]
Expand Down
Loading