117 changes: 100 additions & 17 deletions tests/pytorch/attention/run_attention_with_cp.py
@@ -4,6 +4,9 @@
 
 import os
 import sys
+import copy
+import json
+import traceback
 import logging
 from contextlib import nullcontext
 import torch
@@ -21,6 +24,7 @@
     Float8CurrentScalingQuantizer,
     MXFP8Quantizer,
 )
+from transformer_engine.pytorch.quantization import FP8GlobalStateManager
 from transformer_engine.common.recipe import (
     DelayedScaling,
     Float8CurrentScaling,
Expand Down Expand Up @@ -209,10 +213,10 @@ def run_dpa_with_cp(
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
config = model_configs_flash_attn[model]
config = copy.deepcopy(model_configs_flash_attn[model])
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
config = copy.deepcopy(model_configs_fused_attn[model])
assert config.attn_mask_type in [
"causal",
"no_mask",
@@ -223,18 +227,18 @@
     else:
         config.attn_mask_type = "padding"
 
-    # set up distributed group
-    rank = int(os.getenv("RANK", "0"))
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-    if dist.is_initialized():
-        world_size = dist.get_world_size()
-        rank = dist.get_rank()
-    else:
+    # When called from batch main(), dist is already initialized — reuse it.
+    # When called standalone (legacy single-config), init here.
+    _owns_dist = not dist.is_initialized()
+    if _owns_dist:
+        rank = int(os.getenv("RANK", "0"))
+        world_size = int(os.getenv("WORLD_SIZE", "1"))
         device_count = torch.cuda.device_count()
-        device = rank % device_count
-        torch.cuda.set_device(device)
+        torch.cuda.set_device(rank % device_count)
+        dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
     logging.info(f"[Rank {rank}] Setup: world_size {world_size}")
-    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
 
     # set up communication group for CP
     cp_comm_ranks = range(world_size)
Expand Down Expand Up @@ -630,7 +634,6 @@ def run_dpa_with_cp(
== 0
)
else:
# Forward-only: reshape only out/out_ for comparison
out = out.index_select(0, seq_idx_q).contiguous()
out_ = out_

Expand Down Expand Up @@ -762,14 +765,94 @@ def run_dpa_with_cp(
)
logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches")

# destroy distribution group
dist.destroy_process_group()
dist.destroy_process_group(cp_comm_group)
if cp_comm_type == "a2a+p2p":
for sg in cp_comm_sub_groups:
dist.destroy_process_group(sg)
if _owns_dist:
dist.destroy_process_group()


+_TRANSIENT_ENV_KEYS = (
+    "NVTE_FP8_DPA_BWD",
+    "NVTE_DPA_FP8CS_O_in_F16",
+    "NVTE_FLASH_ATTN",
+    "NVTE_FUSED_ATTN",
+    "NVTE_ALLOW_NONDETERMINISTIC_ALGO",
+)
+
+
+def _init_distributed():
+    rank = int(os.getenv("RANK", "0"))
+    world_size = int(os.getenv("WORLD_SIZE", "1"))
+    device_count = torch.cuda.device_count()
+    local_rank = int(os.getenv("LOCAL_RANK", str(rank % device_count)))
+    torch.cuda.set_device(local_rank)
+    dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
+    return rank, world_size
+
+
+def _run_single_config(kwargs):
+    """Run one config, return ``(ok, error_message)``."""
+    torch.manual_seed(1234)
+    torch.cuda.manual_seed(1234)
+    try:
+        run_dpa_with_cp(**kwargs)
+        return True, None
+    except BaseException:  # noqa: BLE001 - capture any failure for per-config reporting
+        return False, traceback.format_exc()
 def main(**kwargs):
-    run_dpa_with_cp(**kwargs)
+    """Single-config (key=val args) or batch (batch_config_json=<path>) entry point."""
+    batch_path = kwargs.pop("batch_config_json", None)
+    rank, _ = _init_distributed()
+    try:
+        if batch_path is None:
+            run_dpa_with_cp(**kwargs)
+        else:
+            with open(batch_path, "r") as f:
+                configs = json.load(f)
+            assert isinstance(
+                configs, list
+            ), f"batch_config_json must be a JSON list, got {type(configs)}"
+            results_path = batch_path + ".results.json"
+            results = []
+
+            def _flush_results():
+                if rank != 0:
+                    return
+                tmp_path = results_path + ".tmp"
+                with open(tmp_path, "w") as f:
+                    json.dump(results, f)
+                os.replace(tmp_path, results_path)
+
+            for cfg in configs:
+                FP8GlobalStateManager.reset()
+                for env_key in _TRANSIENT_ENV_KEYS:
+                    os.environ.pop(env_key, None)
+                ok, err = _run_single_config(cfg)
+                ok_tensor = torch.tensor(1 if ok else 0, dtype=torch.int32, device="cuda")
+                dist.all_reduce(ok_tensor, op=dist.ReduceOp.MIN)
+                ok_aggregate = bool(ok_tensor.item())
+                if not ok_aggregate and ok and err is None:
+                    err = "Failed on a non-zero rank (see subprocess stderr for traceback)"
Comment on lines +838 to +839 (Contributor):

P1 Non-rank-0 failure traceback is swallowed and the error guidance is wrong

When a non-zero rank fails inside _run_single_config, its traceback is captured in that rank's local err variable but never transmitted to rank 0. The all_reduce propagates the ok=0 flag correctly, but rank 0 only records "Failed on a non-zero rank (see subprocess stderr for traceback)". That guidance is wrong: because _run_single_config catches the exception on rank 1, rank 1 exits cleanly and torchrun exits with code 0 — there is no traceback in subprocess stderr. A developer investigating the failure would find nothing there.

This is a regression from the original non-batched flow where rank 1's uncaught exception printed directly to torchrun's stderr and was captured by run_distributed. A minimal fix is to have the failing rank(s) print their traceback to sys.stderr before returning from _run_single_config, so it appears in torchrun's captured output even when the process exits cleanly.
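
A sketch of that minimal fix (illustrative; the stderr print is the only change to the function as written in this PR):

    def _run_single_config(kwargs):
        """Run one config, return ``(ok, error_message)``."""
        torch.manual_seed(1234)
        torch.cuda.manual_seed(1234)
        try:
            run_dpa_with_cp(**kwargs)
            return True, None
        except BaseException:  # noqa: BLE001 - capture any failure for per-config reporting
            err = traceback.format_exc()
            # Print on the failing rank so the traceback lands in torchrun's
            # captured stderr even though this process now exits cleanly.
            print(err, file=sys.stderr, flush=True)
            return False, err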

+                results.append({"ok": ok_aggregate, "error": err})
+                _flush_results()
+                try:
+                    dist.barrier()
+                except BaseException:  # noqa: BLE001
+                    results[-1]["ok"] = False
+                    if results[-1]["error"] is None:
+                        results[-1]["error"] = traceback.format_exc()
+                    _flush_results()
+                    break
+                torch.cuda.empty_cache()
+    finally:
+        if dist.is_initialized():
+            dist.destroy_process_group()


 if __name__ == "__main__":
-    kwargs = dict(arg.split("=") for arg in sys.argv[2:])
+    kwargs = dict(arg.split("=", 1) for arg in sys.argv[2:])
     main(**kwargs)
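
For reference, the batch entry point added above would be driven roughly as follows. The file names and config values are illustrative only (not taken from this PR), and note that the __main__ block parses sys.argv[2:], so one placeholder argument must precede the key=value pairs:

    import json
    import subprocess

    # One dict per test; keys mirror run_dpa_with_cp's keyword arguments.
    configs = [
        {"dtype": "bf16", "model": "cp_1_0", "qkv_format": "bshd",
         "kernel_backend": "FusedAttention", "cp_comm_type": "p2p"},
    ]
    with open("cp_configs.json", "w") as f:
        json.dump(configs, f)

    # sys.argv[1] is ignored by the script's __main__ block, hence "placeholder".
    subprocess.run(
        ["torchrun", "--nproc_per_node=2", "run_attention_with_cp.py",
         "placeholder", "batch_config_json=cp_configs.json"],
        check=True,
    )
    # Rank 0 writes per-config pass/fail records to cp_configs.json.results.json,
    # flushed atomically (tmp file + os.replace) after every config.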