
Commit

make sure Fp8 weight buffers are sharded at the end of the backward pass and gathered before forward

Signed-off-by: Alp Dener <adener@nvidia.com>
denera committed Feb 28, 2024
1 parent 23c0cd5 commit d18b49f
Showing 3 changed files with 69 additions and 5 deletions.
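
The change follows one pattern in all three modules: gather the FSDP-sharded FP8 weight buffers to their full shape right before the forward pass needs them, and re-shard (scatter) them once the forward or backward pass is done with them. Below is a minimal standalone sketch of that gather/scatter round trip; the helper implementations are illustrative assumptions, not the `_fsdp_gather_tensors`/`_fsdp_scatter_tensors` utilities the diff actually calls (those operate in place on `Float8Tensor._data`).

```python
# Illustrative sketch only: flat-shard <-> full-shape round trip for a weight
# buffer under FSDP-style sharding, assuming every rank holds an equal-sized,
# padded flat slice.
import math
import torch
import torch.distributed as dist


def gather_full_buffer(group, shard: torch.Tensor, full_shape) -> torch.Tensor:
    """All-gather a flattened per-rank shard and reshape it to the full weight shape."""
    world = dist.get_world_size(group)
    chunks = [torch.empty_like(shard.reshape(-1)) for _ in range(world)]
    dist.all_gather(chunks, shard.reshape(-1).contiguous(), group=group)
    flat = torch.cat(chunks)
    return flat[: math.prod(full_shape)].reshape(full_shape)  # drop any padding


def scatter_to_shard(group, full: torch.Tensor) -> torch.Tensor:
    """Keep only this rank's flat slice so the full buffer can be freed."""
    world = dist.get_world_size(group)
    rank = dist.get_rank(group)
    flat = full.reshape(-1)
    per_rank = (flat.numel() + world - 1) // world
    padded = torch.zeros(per_rank * world, dtype=flat.dtype, device=flat.device)
    padded[: flat.numel()] = flat
    return padded[rank * per_rank : (rank + 1) * per_rank].clone()
```

In the diff, this round trip happens in place: the forward gathers `weight_fp8._data` (and the transposed buffer when a fused cast-transpose is needed), scatters it again before saving tensors for backward, and the backward re-scatters `weight_t_fp8` once the weight gradient has been computed.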
20 changes: 19 additions & 1 deletion transformer_engine/pytorch/module/layernorm_linear.py
@@ -172,6 +172,9 @@ def forward(
weight_fp8 = weight
weight_t_fp8 = None
elif update_fp8_weights:
# Gather Fp8 weight buffers if needed
if fsdp_group is not None and weight_fp8._data.shape != weight.data.shape:
_fsdp_gather_tensors(fsdp_group, [weight.data.shape], weight_fp8)
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
data=weight_fp8._data,
@@ -181,6 +184,12 @@
if (is_grad_enabled
or (is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase())):
# Gather Fp8 transposed-weight buffers if needed
if (fsdp_group is not None
and weight_t_fp8._data.shape != reversed(weight.data.shape)):
_fsdp_gather_tensors(fsdp_group,
[tuple(reversed(weight.data.shape))],
weight_t_fp8)
tex.fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
@@ -261,12 +270,15 @@ def forward(
rsigma.activation_offloading = True
ln_out.activation_offloading = True

# Scatter Fp8 weight buffers
_fsdp_scatter_tensors(fsdp_group, weight_fp8, weight_fp8)

# Scatter intermediate/activation tensors saved for the backward pass
ctx.fsdp_group = fsdp_group
ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group,
mu,
rsigma,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8,
ln_out
)
@@ -338,6 +350,7 @@ def backward(
fwd_scale_inverses,
) = ctx.saved_tensors

# Gather intermediate/activation tensors if needed
_fsdp_gather_tensors(
ctx.fsdp_group,
ctx.fsdp_shapes,
@@ -575,6 +588,8 @@ def backward(
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
dbeta = None
clear_tensor_data(mu)
clear_tensor_data(rsigma)

if not ctx.use_bias:
grad_bias = None
@@ -600,6 +615,9 @@
else:
wgrad = None

# Scatter fp8 transposed-weight buffers
_fsdp_scatter_tensors(ctx.fsdp_group, weight_t_fp8)

return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
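
In each of these hunks the gather is guarded by a shape comparison: if FSDP still holds the buffer as a shard, its local `_data` shape will not match the full weight shape (or the reversed shape, for the cached FP8 transpose). A hedged sketch of that guard, with a hypothetical `needs_gather` helper that is not part of Transformer Engine:

```python
# Hypothetical helper: decide whether a locally held FP8 buffer is still an
# FSDP shard and therefore needs an all-gather before use.
import torch


def needs_gather(local_buffer: torch.Tensor, full_shape: torch.Size,
                 transposed: bool = False) -> bool:
    # The reversed shape must be materialized with tuple(): comparing a
    # torch.Size against a bare reversed(...) iterator never compares equal.
    expected = tuple(reversed(full_shape)) if transposed else tuple(full_shape)
    return tuple(local_buffer.shape) != expected
```

The transposed case matters because `weight_t_fp8` caches the FP8 transpose with swapped dimensions, so its full shape is `tuple(reversed(weight.data.shape))`.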
32 changes: 32 additions & 0 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -216,6 +216,17 @@ def forward(
fc1_weight_t_fp8 = None
fc2_weight_t_fp8 = None
elif update_fp8_weights:
# Gather Fp8 weight buffers if needed
if fsdp_group is not None:
weights_to_gather = []
gather_shapes = []
if fc1_weight_fp8._data.shape != fc1_weight.data.shape:
weights_to_gather.append(fc1_weight_fp8)
gather_shapes.append(fc1_weight.data.shape)
if fc2_weight_fp8._data.shape != fc2_weight.data.shape:
weights_to_gather.append(fc2_weight_fp8)
gather_shapes.append(fc2_weight.data.shape)
_fsdp_gather_tensors(fsdp_group, gather_shapes, weights_to_gather)
# Need to cast weights to FP8
fc1_weight_fp8 = Float8Tensor(
data=fc1_weight_fp8._data,
@@ -230,6 +241,17 @@
if (is_grad_enabled
or (is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase())):
# Gather Fp8 transposed-weight buffers if needed
if fsdp_group is not None:
weights_to_gather = []
gather_shapes = []
if fc1_weight_t_fp8._data.shape != reversed(fc1_weight.data.shape):
weights_to_gather.append(fc1_weight_t_fp8)
gather_shapes.append(tuple(reversed(fc1_weight.data.shape)))
if fc2_weight_t_fp8._data.shape != reversed(fc2_weight.data.shape):
weights_to_gather.append(fc2_weight_t_fp8)
gather_shapes.append(tuple(reversed(fc2_weight.data.shape)))
_fsdp_gather_tensors(fsdp_group, gather_shapes, weights_to_gather)
# Fused cast-transpose kernels
tex.fp8_cast_transpose_fused(
fc1_weight,
@@ -473,6 +495,10 @@ def forward(
fc1_out.activation_offloading = True
gelu_out.activation_offloading = True

# Scatter Fp8 weight buffers
_fsdp_scatter_tensors(fsdp_group, fc1_weight_fp8, fc2_weight_fp8)

# Scatter intermediate/activation tensors saved for the backward pass
ctx.fsdp_group = fsdp_group
ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group,
@@ -1000,6 +1026,8 @@ def backward(
ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma
)
dbeta = None
clear_tensor_data(mu)
clear_tensor_data(rsigma)

if fc1_weight.requires_grad:
# Handle custom DDP from mcore.
@@ -1043,6 +1071,10 @@
else:
fc2_wgrad = None

# Scatter Fp8 transposed-weight buffers
_fsdp_scatter_tensors(ctx.fsdp_group, fc1_weight_t_fp8)
_fsdp_scatter_tensors(ctx.fsdp_group, fc2_weight_t_fp8)

return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
dgamma,
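
Because the MLP module has two FP8 weights (and two transposed caches), the hunks above collect every still-sharded buffer and its target shape into lists and issue a single gather call for the batch. A minimal sketch of that batching, assuming a hypothetical `gather_fn` with the signature `(group, shapes, buffers)`:

```python
# Illustrative batching sketch: one gather call for however many FP8 buffers
# are still sharded, instead of one call per buffer.
from typing import Callable, List, Sequence, Tuple

import torch


def gather_all_sharded(group,
                       buffers_and_shapes: Sequence[Tuple[torch.Tensor, Sequence[int]]],
                       gather_fn: Callable) -> None:
    to_gather: List[torch.Tensor] = []
    shapes: List[Tuple[int, ...]] = []
    for buf, full_shape in buffers_and_shapes:
        if tuple(buf.shape) != tuple(full_shape):  # still sharded locally
            to_gather.append(buf)
            shapes.append(tuple(full_shape))
    if to_gather:
        gather_fn(group, shapes, to_gather)
```

Batching the FC1 and FC2 buffers this way can reduce the number of collectives launched per step, assuming the gather helper coalesces its inputs.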
22 changes: 18 additions & 4 deletions transformer_engine/pytorch/module/linear.py
@@ -154,6 +154,9 @@ def forward(
weight_fp8 = weight
weight_t_fp8 = None
elif update_fp8_weights:
# Gather Fp8 weight buffers if needed
if fsdp_group is not None and weight_fp8._data.shape != weight.data.shape:
_fsdp_gather_tensors(fsdp_group, [weight.data.shape], weight_fp8)
# Need to cast weights to FP8
weight_fp8 = Float8Tensor(
data=weight_fp8._data,
@@ -163,6 +166,12 @@
if (is_grad_enabled
or (is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase())):
# Gather Fp8 transposed-weight buffers if needed
if (fsdp_group is not None
and weight_t_fp8._data.shape != reversed(weight.data.shape)):
_fsdp_gather_tensors(fsdp_group,
[tuple(reversed(weight.data.shape))],
weight_t_fp8)
fp8_cast_transpose_fused(
weight,
fp8_meta["scaling_fwd"],
@@ -290,13 +299,15 @@ def forward(
if saved_inputmat is not None:
saved_inputmat.activation_offloading = True

fwd_scale_inverses = fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None
# Scatter Fp8 weight buffers
_fsdp_scatter_tensors(fsdp_group, weight_fp8, weight_fp8)

# Scatter intermediate/activation tensors saved for the backward pass
ctx.fsdp_group = fsdp_group
ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group,
saved_inputmat, # None if fp8 == False
saved_inputmat_t, # None if fp8 == False AND not is_grad_enabled
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8 if fp8 else None,
)

@@ -306,7 +317,7 @@ def forward(
weight,
weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None,
weight_t_fp8 if fp8 else None,
fwd_scale_inverses
fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None
)

ctx.activation_dtype = activation_dtype
@@ -355,11 +366,11 @@ def backward(
fwd_scale_inverses,
) = ctx.saved_tensors

# Gather intermediate/activation tensors if needed
_fsdp_gather_tensors(ctx.fsdp_group,
ctx.fsdp_shapes,
inputmat,
inputmat_t,
main_grad,
weight_t_fp8)

if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
@@ -539,6 +550,9 @@ def backward(
else:
wgrad = None

# Scatter fp8 transposed-weight buffers
_fsdp_scatter_tensors(ctx.fsdp_group, weight_t_fp8)

return (
wgrad,
None,
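
The intermediate activations saved for backward follow the same round trip: the forward records each tensor's full shape (stashed as `ctx.fsdp_shapes`) while scattering it, and the backward uses those shapes to gather everything back before use. Below is a standalone sketch of that bookkeeping under the assumption that tensors are flat-sharded in place with padding; the helper names are illustrative, not the Transformer Engine ones.

```python
# Illustrative sketch of "scatter in forward, remember shapes, gather in
# backward", as used for saved activations and the transposed FP8 weight.
import math
from typing import List, Optional

import torch
import torch.distributed as dist


def scatter_saved(group, *tensors: Optional[torch.Tensor]) -> List[Optional[torch.Size]]:
    """Shard each tensor in place and return the full shapes needed to restore it."""
    shapes: List[Optional[torch.Size]] = []
    world, rank = dist.get_world_size(group), dist.get_rank(group)
    for t in tensors:
        if t is None:
            shapes.append(None)
            continue
        shapes.append(t.shape)
        flat = t.reshape(-1)
        per_rank = (flat.numel() + world - 1) // world
        padded = torch.zeros(per_rank * world, dtype=flat.dtype, device=flat.device)
        padded[: flat.numel()] = flat
        t.data = padded[rank * per_rank:(rank + 1) * per_rank].clone()
    return shapes


def gather_saved(group, shapes, *tensors: Optional[torch.Tensor]) -> None:
    """Restore each sharded tensor to the full shape recorded by scatter_saved."""
    world = dist.get_world_size(group)
    for t, shape in zip(tensors, shapes):
        if t is None or shape is None:
            continue
        chunks = [torch.empty_like(t.data) for _ in range(world)]
        dist.all_gather(chunks, t.data.contiguous(), group=group)
        t.data = torch.cat(chunks)[: math.prod(shape)].reshape(shape)
```

This mirrors how the diff wires the pair together: `ctx.fsdp_shapes = _fsdp_scatter_tensors(...)` at the end of `forward`, and `_fsdp_gather_tensors(ctx.fsdp_group, ctx.fsdp_shapes, ...)` at the start of `backward`.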
