Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions transformer_engine/pytorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,7 @@ def _make_graphed_callables(

if cache_quantized_params:
# Initialize flag that controls FP8 weight updates
qstate = FP8GlobalStateManager.quantization_state
if qstate.skip_fp8_weight_update_tensor is None:
qstate.skip_fp8_weight_update_tensor = torch.empty(
1, dtype=torch.float32, device="cuda"
)
qstate.skip_fp8_weight_update_tensor.fill_(False)
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False)

# Check callables
for c in callables:
Expand Down Expand Up @@ -841,9 +836,7 @@ def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *i
# Set flag that controls whether FP8 weight updates are skipped
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
if ctx.is_first_module and skip_fp8_weight_update is not None:
FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor.fill_(
skip_fp8_weight_update
)
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update)
ctx.cuda_graph_stream = cuda_graph_stream
ctx.cuda_graph_event = cuda_graph_event
# Copy values from new tensors into static tensors
Expand Down
14 changes: 14 additions & 0 deletions transformer_engine/pytorch/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,20 @@ class FP8GlobalStateManager:

quantization_state = FP8GlobalState()

@classmethod
def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None:
    """Write ``skip`` into the skip-FP8-weight-update flag tensor.

    The flag is stored on ``cls.quantization_state``. If it has not been
    created yet, a one-element float32 tensor is allocated on the CUDA
    device and stored there first. The stored tensor is then overwritten
    in place with ``skip``.
    """
    state = cls.quantization_state
    flag = state.skip_fp8_weight_update_tensor
    if flag is None:
        # First use: allocate the flag and keep it on the global state.
        flag = torch.empty(1, dtype=torch.float32, device="cuda")
        state.skip_fp8_weight_update_tensor = flag
    flag.fill_(skip)

@classmethod
def get_skip_fp8_weight_update_tensor(cls) -> Optional[torch.Tensor]:
    """Return the skip-FP8-weight-update flag tensor from the global state.

    The value is returned exactly as stored on ``cls.quantization_state``
    and may be ``None``.
    """
    state = cls.quantization_state
    return state.skip_fp8_weight_update_tensor

@classmethod
def reset(cls) -> None:
"""Reset the global state"""
Expand Down
Loading