Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions benchmarks/linear/benchmark_linear_cpu_overhead.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import List
import torch
import time
import argparse

import transformer_engine.pytorch as te


def run_once(
    module: torch.nn.Module,
    args: List[torch.Tensor],
    iters=2000,
    use_te: bool = True,
    is_first_microbatch: bool = True,
):
    """Call ``module`` on ``args`` ``iters`` times, discarding the outputs.

    Args:
        module: Callable module under test.
        args: Positional arguments unpacked into every call.
        iters: Number of back-to-back invocations.
        use_te: When True, ``is_first_microbatch`` is forwarded as a keyword
            argument on every call; when False, the module is called with the
            positional arguments only.
        is_first_microbatch: Value forwarded to the module when ``use_te`` is True.
    """
    # Decide once which keyword arguments accompany each call, then run a single loop.
    extra_kwargs = {"is_first_microbatch": is_first_microbatch} if use_te else {}
    for _ in range(iters):
        module(*args, **extra_kwargs)


def speedometer(
    module: torch.nn.Module,
    args: List[torch.Tensor],
    timing_iters: int = 2000,
    warmup_iters: int = 100,
    num_rounds: int = 5,
    use_te: bool = True,
) -> float:
    """Measure average run time for a PyTorch module.

    Runs ``warmup_iters`` untimed calls, then ``num_rounds`` timed rounds of
    ``timing_iters`` calls each. For every round both the GPU time (CUDA events)
    and the host-side time spent issuing the calls are printed per iteration in
    microseconds, followed by the averages over all rounds.

    Args:
        module: Module under test.
        args: Positional arguments passed to the module on every call.
        timing_iters: Number of calls per timed round (must be positive).
        warmup_iters: Number of untimed warm-up calls.
        num_rounds: Number of timed rounds (must be positive).
        use_te: Forwarded to ``run_once``; selects whether ``is_first_microbatch``
            is passed to the module.

    Returns:
        Average GPU time of one round (``timing_iters`` calls), in milliseconds.

    Raises:
        ValueError: If ``timing_iters`` or ``num_rounds`` is not positive.
    """
    # Fail early instead of hitting a ZeroDivisionError after the work has run.
    if timing_iters <= 0 or num_rounds <= 0:
        raise ValueError(
            "timing_iters and num_rounds must be positive, got "
            f"timing_iters={timing_iters}, num_rounds={num_rounds}"
        )

    # warm up
    run_once(module, args, iters=warmup_iters, use_te=use_te, is_first_microbatch=True)

    gpu_times = []  # per-round GPU time in ms
    cpu_times = []  # per-round host time in ms
    for round_idx in range(num_rounds):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        # Drain pending GPU work so it is not attributed to this round.
        torch.cuda.synchronize()
        start.record()
        # perf_counter is monotonic and high resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        cpu_start = time.perf_counter()
        # run timing
        run_once(module, args, iters=timing_iters, use_te=use_te, is_first_microbatch=False)
        # Host timer stops before synchronizing: it captures the time spent
        # issuing the calls, not the time for the GPU to finish them.
        cpu_end = time.perf_counter()
        end.record()
        torch.cuda.synchronize()
        gpu_elapsed = start.elapsed_time(end)  # milliseconds
        cpu_elapsed = (cpu_end - cpu_start) * 1000  # seconds -> milliseconds
        gpu_times.append(gpu_elapsed)
        cpu_times.append(cpu_elapsed)
        print(
            f"Round {round_idx+1}/{num_rounds}: GPU {gpu_elapsed/timing_iters*1000:.2f} µs, CPU"
            f" {cpu_elapsed/timing_iters*1000:.2f} µs"
        )
    print(
        f"Average GPU time over {num_rounds} rounds:"
        f" {sum(gpu_times)/(num_rounds*timing_iters)*1000:.2f} µs"
    )
    print(
        f"Average CPU time over {num_rounds} rounds:"
        f" {sum(cpu_times)/(num_rounds*timing_iters)*1000:.2f} µs"
    )

    return sum(gpu_times) / num_rounds


def main():
    """Parse command-line options, build the selected Linear layer and benchmark it.

    The layer is a square ``hidden_size x hidden_size`` bias-free linear layer in
    bfloat16, applied to a ``(seq_length, hidden_size)`` input on the CUDA device.
    Per-round timings are printed by ``speedometer``; an estimated TFLOP/s figure
    is printed at the end.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark torch.nn.Linear performance and CPU overhead."
    )
    parser.add_argument("--hidden_size", type=int, default=3072, help="Hidden size")
    parser.add_argument("--seq_length", type=int, default=2048, help="Sequence length")
    parser.add_argument("--warmup", type=int, default=500, help="Number of warmup iterations")
    parser.add_argument(
        "--timing_iters", type=int, default=2000, help="Number of timing iterations per round"
    )
    parser.add_argument("--num_rounds", type=int, default=5, help="Number of timing rounds")
    parser.add_argument(
        "--backend",
        type=str,
        choices=["torch", "te"],
        default="te",
        help="Linear backend: torch or te",
    )
    args = parser.parse_args()

    # NOTE(review): requires_grad=True is set on the input although the benchmark
    # below runs under torch.no_grad(), so no autograd graph is recorded.
    # requires_grad=False would be more consistent with that context.
    x = torch.randn(
        (args.seq_length, args.hidden_size), dtype=torch.bfloat16, device="cuda", requires_grad=True
    )
    use_te = True
    if args.backend == "torch":
        model = (
            torch.nn.Linear(args.hidden_size, args.hidden_size, bias=False)
            .to(torch.bfloat16)
            .cuda()
        )
        use_te = False
    else:
        model = te.Linear(args.hidden_size, args.hidden_size, bias=False, device="cuda").to(
            torch.bfloat16
        )
    with torch.no_grad():
        avg_gpu_time_per_round = speedometer(
            model,
            [x],
            timing_iters=args.timing_iters,
            warmup_iters=args.warmup,
            num_rounds=args.num_rounds,
            use_te=use_te,
        )

    # FLOPs issued in ONE round: 2 * M * N * K per GEMM, times timing_iters calls.
    total_ops = 2 * args.hidden_size * args.hidden_size * args.seq_length * args.timing_iters

    # speedometer returns the GPU time of one round already averaged over
    # num_rounds, in milliseconds, so both operands are per-round quantities.
    # ops / ms / 1e9 == ops / s / 1e12 == TFLOP/s.
    tflops = total_ops / avg_gpu_time_per_round / 1e9
    print(f"Estimated TFLOP/s: {tflops:.2f}")


if __name__ == "__main__":
    main()
18 changes: 11 additions & 7 deletions transformer_engine/common/transformer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -717,11 +717,15 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
}

int nvte_is_non_tn_fp8_gemm_supported() {
  // The architecture check is evaluated once, for the device that is current on
  // the first call, and the result is reused on every later call so the device
  // does not have to be queried again.
  //
  // NOTE(review): because the result is cached process-wide, this does not
  // handle a system that mixes GPUs of different architectures; it assumes a
  // homogeneous topology (e.g. one host with 4 or 8 identical GPUs). Adding a
  // device-ID argument would handle the mixed case but reintroduces a
  // per-call lookup.
  static int cached_result = 0;
  static std::once_flag flag;
  std::call_once(flag, []() {
    int deviceComputeCapability =
        transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
    // Note: this is temporary restriction and should be lifted in the future.
    // (remove the note once it's done.)
    cached_result = (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
                    deviceComputeCapability >= 130;
  });
  return cached_result;
}
13 changes: 10 additions & 3 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
import warnings
import contextlib

import functools
import torch
Expand All @@ -28,6 +29,7 @@
clear_tensor_data,
init_method_constant,
requires_grad,
should_set_cuda_device_every_batch,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
Expand Down Expand Up @@ -766,9 +768,14 @@ def forward(
if skip_fp8_weight_update is not None:
is_first_microbatch = False

with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
if should_set_cuda_device_every_batch():
device_ctx = torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
)
else:
device_ctx = contextlib.nullcontext()
Comment on lines +771 to +776
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: Retrieving device via list(self.named_parameters())[0][0] creates a list every forward pass. Consider caching during initialization to avoid repeated computation.


with device_ctx, self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

Expand Down
16 changes: 12 additions & 4 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Callable, Dict, Optional, Tuple, Union, List
from functools import reduce
from operator import mul as multiply_op
import contextlib

import torch
from torch.nn import init
Expand Down Expand Up @@ -40,6 +41,7 @@
nvtx_range_push,
requires_grad,
needs_quantized_gemm,
should_set_cuda_device_every_batch,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
Expand Down Expand Up @@ -1532,10 +1534,16 @@ def forward(
).is_fp8_ubuf():
fp8_grad = True

with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
if should_set_cuda_device_every_batch():
device_ctx = torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
)
else:
device_ctx = contextlib.nullcontext()
Comment on lines +1537 to +1542
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: The device is retrieved by accessing the first parameter via list(self.named_parameters())[0][0]. This creates a list from the iterator on every forward pass. Consider caching the device during initialization (e.g., in __init__ or reset_parameters) to avoid repeated computation. Is there a reason the device cannot be cached during initialization?


with device_ctx, self.prepare_forward(
inp,
allow_non_contiguous=False, # removed .contiguous from inside the layer
) as inp:

# Get concatenated weight and bias tensors
Expand Down
13 changes: 10 additions & 3 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Callable, Optional, Tuple, Union, List
from functools import reduce
from operator import mul as multiply_op
import contextlib

import torch
from torch.nn.parameter import Parameter
Expand Down Expand Up @@ -45,6 +46,7 @@
clear_tensor_data,
requires_grad,
needs_quantized_gemm,
should_set_cuda_device_every_batch,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
Expand Down Expand Up @@ -1806,9 +1808,14 @@ def forward(
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
fp8_output = True

with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(inp, num_gemms=2) as inp:
if should_set_cuda_device_every_batch():
device_ctx = torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
)
else:
device_ctx = contextlib.nullcontext()

with device_ctx, self.prepare_forward(inp, num_gemms=2) as inp:

quantizers = (
self._get_quantizers(fp8_output)
Expand Down
13 changes: 10 additions & 3 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from functools import reduce
from operator import mul as multiply_op
import warnings
import contextlib

import torch

Expand Down Expand Up @@ -38,6 +39,7 @@
assert_dim_for_all_gather,
nvtx_range_pop,
nvtx_range_push,
should_set_cuda_device_every_batch,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
Expand Down Expand Up @@ -1419,9 +1421,14 @@ def forward(
).is_fp8_ubuf():
fp8_grad = True

with torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
), self.prepare_forward(
if should_set_cuda_device_every_batch():
device_ctx = torch.cuda.device(
getattr(self, list(self.named_parameters())[0][0]).device
)
else:
device_ctx = contextlib.nullcontext()

with device_ctx, self.prepare_forward(
inp,
allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp:
Expand Down
27 changes: 27 additions & 0 deletions transformer_engine/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,33 @@ def get_cudnn_version() -> Tuple[int, int, int]:
return (major, minor, patch)


# Process-wide switch read by should_set_cuda_device_every_batch().
# NVTE_SET_CUDA_DEVICE unset -> enabled; set -> enabled only when the value,
# ignoring surrounding whitespace, is exactly "1".
_env_var = os.environ.get("NVTE_SET_CUDA_DEVICE")
_set_cuda_device_every_batch = True if _env_var is None else _env_var.strip() == "1"


def set_cuda_device_every_batch(enabled: bool = True) -> None:
    """Enable or disable the per-forward CUDA device context.

    When enabled, the module forward methods that consult
    `should_set_cuda_device_every_batch` wrap each forward pass in
    `torch.cuda.device(...)`, using the device of the module's first parameter.
    The setting is enabled by default unless the NVTE_SET_CUDA_DEVICE environment
    variable overrides it; calling this function replaces whatever value is
    currently in effect.
    """
    global _set_cuda_device_every_batch
    _set_cuda_device_every_batch = enabled


def should_set_cuda_device_every_batch() -> bool:
    """Report whether forward passes should set the CUDA device context.

    Returns the value currently in effect: the one derived from
    NVTE_SET_CUDA_DEVICE at import time, or the one most recently passed to
    `set_cuda_device_every_batch`.
    """
    return _set_cuda_device_every_batch


def canonicalize_device(device: Optional[torch.device | str]) -> torch.device:
"""Canonicalize PyTorch device

Expand Down