From 29413187eb6a84a8032032e7f033371f6f83e47c Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 19 Jan 2024 09:03:24 -0800 Subject: [PATCH 01/15] Avoid using torch.compile for roll and fill_ (#609) Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 7bec34c861..d4d82cf0be 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -583,7 +583,7 @@ def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: return amax_history -@jit_fuser +@torch.jit.script def _default_get_amax( amax_history: torch.Tensor, amax_compute_algo: str, @@ -625,7 +625,7 @@ def _compute_scaling_factor_inverse( return torch.where(non_weight_mask, 1.0 / scale, scale_inv) -@jit_fuser +@torch.jit.script def _fused_amax_and_scale_update( amax_history: torch.Tensor, scale: torch.Tensor, From f26690abfcb863a78bfb32f91f4121537b2d07a3 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 17 Jan 2024 11:58:22 -0600 Subject: [PATCH 02/15] [PyTorch] Deferred Initialization via `device='meta'` option (#596) * Implemented deferred initialization via `device='meta'` option for te.Linear and added new PyTorch example to demonstrate its use with FullyShardedDataParallel execution. Signed-off-by: Alp Dener * correcting Float8Tensor initialization and fixing linting errors Signed-off-by: Alp Dener * removed duplicate code from upstream rebase, local tests passing Signed-off-by: Alp Dener * improved comments/documentation for FSDP example Signed-off-by: Alp Dener * converted reset_parameters() into a base module function Signed-off-by: Alp Dener * fixed Float8Tensor creation with deferred init, all tests passing locally Signed-off-by: Alp Dener * extended deferred initialization to all TE modules Signed-off-by: Alp Dener * fixed linting errors Signed-off-by: Alp Dener * removed unnecessary reference to the parent module of parameter, added clarifying comments in parameter reset Signed-off-by: Alp Dener --------- Signed-off-by: Alp Dener --- examples/pytorch/fsdp/README.md | 53 +++++ examples/pytorch/fsdp/fsdp.py | 195 ++++++++++++++++++ transformer_engine/pytorch/module/_common.py | 19 +- transformer_engine/pytorch/module/base.py | 49 +++++ .../pytorch/module/layernorm.py | 17 +- .../pytorch/module/layernorm_linear.py | 57 +++-- .../pytorch/module/layernorm_mlp.py | 92 ++++----- transformer_engine/pytorch/module/linear.py | 48 ++--- transformer_engine/pytorch/module/rmsnorm.py | 15 +- transformer_engine/pytorch/utils.py | 15 ++ 10 files changed, 441 insertions(+), 119 deletions(-) create mode 100644 examples/pytorch/fsdp/README.md create mode 100644 examples/pytorch/fsdp/fsdp.py diff --git a/examples/pytorch/fsdp/README.md b/examples/pytorch/fsdp/README.md new file mode 100644 index 0000000000..d492ea4a57 --- /dev/null +++ b/examples/pytorch/fsdp/README.md @@ -0,0 +1,53 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# Basic Example for Using PyTorch Fully Sharded Data Parallel mode with Transformer Engine + +```bash +# FSDP without deferred initialization: +# Duplicate modules initialized on each device. Load on device memory reduced only after +# torch.distributed.fsdp.FullyShardedDataParallel mode shards model parameters. 
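# The command below launches one process per local GPU on a single node;
# `nvidia-smi -L | wc -l` simply counts the GPUs visible on the node.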
+$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py +# Sample output on 8xL40S: +# [GPU-0] WORLD_SIZE = 8 +# [GPU-0] TransformerEngine Model: +# TransformerLayer( +# (self_attention): MultiheadAttention( +# (layernorm_qkv): LayerNormLinear() +# (core_attention): DotProductAttention( +# (flash_attention): FlashAttention() +# (fused_attention): FusedAttention() +# (unfused_attention): UnfusedDotProductAttention( +# (scale_mask_softmax): FusedScaleMaskSoftmax() +# (attention_dropout): Dropout(p=0.1, inplace=False) +# ) +# ) +# (proj): Linear() +# ) +# (layernorm_mlp): LayerNormMLP() +# ) +# [GPU-0] Pre-FSDP memory use = 83.935232MiB +# [GPU-0] Post-FSDP memory use = 10.491904MiB +# [GPU-0] Iter. 1 +# [GPU-0] Iter. 2 +# [GPU-0] Iter. 3 +# [GPU-0] Training Time: 6.647654296875s +# [GPU-0] Avg. Iter. Time: 2.2158847656250003s +# [GPU-0] Peak memory use = 3000MiB + +# FSDP with deferred initialization: +# Modules initialized with empty paramaters via `device='meta'` option. Zero load on device +# memory until torch.distributed.fsdp.FullyShardedDataParallel mode triggers a reset on +# on already sharded model parameters. +$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py --defer-init +# Sample output on 8xL40S: +# [GPU-0] WORLD_SIZE = 8 +# ... +# [GPU-0] Pre-FSDP memory use = 0.0MiB +# [GPU-0] Post-FSDP memory use = 10.491904MiB +# ... +``` + +**NOTE:** This example has `fp8_autocast()` enabled by default. To run on GPUs without Fp8 support +(e.g.: A100), add the `--no-fp8` option to the commands shown above. diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py new file mode 100644 index 0000000000..5d30be6c97 --- /dev/null +++ b/examples/pytorch/fsdp/fsdp.py @@ -0,0 +1,195 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
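+#
+# Minimal example (see README.md): wrap a stack of Transformer Engine modules with
+# torch.distributed.fsdp.FullyShardedDataParallel, optionally deferring parameter
+# initialization until after sharding via the device='meta' option (--defer-init).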
+ +import os +import argparse +from functools import partial + +import torch +import torch.distributed as dist +from torch import nn +from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision +from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy + +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Format, DelayedScaling + +def lowercase(s): + return str(s).lower() + +def torch_dtype(d): + typemap = { + 'fp32' : torch.float32, + 'float32' : torch.float32, + 'fp16' : torch.float16, + 'float16' : torch.float16, + 'bf16' : torch.bfloat16, + 'bfloat16' : torch.bfloat16 + } + if lowercase(d) not in typemap.keys(): + raise TypeError + return typemap[lowercase(d)] + +te_layer_map = { + 'linear': te.Linear, + 'layernorm': te.LayerNorm, + 'rmsnorm': te.RMSNorm, + 'layernormlinear': te.LayerNormLinear, + 'layernormmlp': te.LayerNormMLP, + 'multiheadattention': te.MultiheadAttention, + 'transformerlayer': te.TransformerLayer +} +def te_layer(l): + if lowercase(l) not in te_layer_map.keys(): + raise TypeError + return te_layer_map[lowercase(l)] + +def get_layer_args(args): + hidden_size = args.num_heads * args.head_dim + layer_args = (hidden_size, ) + layer_kwargs = { + 'params_dtype': args.dtype, + 'device': 'meta' if args.defer_init else 'cuda' + } + if args.layer_type in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]: + ffn_hidden_size = 3 * hidden_size if args.num_layers == 1 else hidden_size + layer_args += (ffn_hidden_size, ) + layer_kwargs['bias'] = True + if args.layer_type == te.LayerNormMLP: + layer_kwargs['seq_length'] = args.seq_length + elif args.layer_type == te.MultiheadAttention: + layer_args += (args.num_heads, ) + layer_kwargs['fuse_qkv_params'] = True + elif args.layer_type == te.TransformerLayer: + layer_args += (3 * hidden_size, args.num_heads) + layer_kwargs['fuse_qkv_params'] = True + layer_kwargs['seq_length'] = args.seq_length + return layer_args, layer_kwargs + +def parse_fsdp_args(): + parser = argparse.ArgumentParser(description="Run Transformer Engine modules with the " + + "torch.distributed.fsdp.FullyShardedDataParallel strategy.") + parser.add_argument("-t", "--layer-type", type=te_layer, default=te.TransformerLayer, + choices=list(te_layer_map.values()), + help="TE module type used to construct the test model.") + parser.add_argument("--no-fp8", action="store_true", default=False, + help="Disables the te.fp8_autocast() context.") + parser.add_argument('-i', "--num-iters", type=int, default=3, + help="Number of dummy 'training' iterations.") + parser.add_argument('-b', "--batch-size", type=int, default=32, + help="Input batch size.") + parser.add_argument('-s', "--seq-length", type=int, default=1048, + help="Input sequence length.") + parser.add_argument('-n', "--num-heads", type=int, default=16, + help="Number of attention heads.") + parser.add_argument('-d', "--head-dim", type=int, default=128, + help="Dimension of each attention head (number of KV channels).") + parser.add_argument('-l', "--num-layers", type=int, default=1, + help="Number of modules chained together with nn.Sequential.") + parser.add_argument("--seed", type=int, default=1234, + help="PyTorch RNG seed.") + parser.add_argument("--defer-init", action="store_true", + help="Defer module parameter initialization until after FSDP sharding.") + parser.add_argument('-v', "--verbose", action="store_true", default=False, + help="Print out information from all GPUs instead of only the root GPU-0.") + 
parser.add_argument("--dtype", type=torch_dtype, default=torch.bfloat16, + help="Data type for input tensor and Transformer Engine module parameters.") + return parser.parse_args() + +def train(args): + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # Initialize torch.distributed global process group + dist.init_process_group(backend="nccl") + torch.cuda.set_device(local_rank) + if local_rank == 0: + print(f"[GPU-0] WORLD_SIZE = {world_size}\n\n", end='') + torch.manual_seed(args.seed) + + # Construct a simple homogeneous model (only one layer type) with NO PARALLELISM + layer_args, layer_kwargs = get_layer_args(args) + if args.num_layers > 1: + te_layer_list = [] + for i in range(args.num_layers): + if args.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + layer_kwargs['layer_number'] = i+1 + te_layer_list.append(args.layer_type(*layer_args, **layer_kwargs)) + te_model = nn.Sequential(*te_layer_list) + else: + # Single layer model + te_model = args.layer_type(*layer_args, **layer_kwargs) + if local_rank == 0: + print(f"[GPU-0] TransformerEngine Model:\n{te_model}\n", end='') + + # Print out allocated device memory before the model parameters are sharded by FSDP + pre_mem_use = torch.cuda.memory_allocated(device=f"cuda:{local_rank}") * 1e-6 + if local_rank == 0 or args.verbose: + print(f"[GPU-{local_rank}] Pre-FSDP memory use = {pre_mem_use}MiB\n", end='') + + # Wrap the model with FSDP + # NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and + # controls all communication. + all_gpus = dist.new_group(backend='nccl') + fsdp_wrap_policy = always_wrap_policy + if args.layer_type == te.TransformerLayer: + # NOTE: FSDP causes illegal memory access without this special policy for Transformers + fsdp_wrap_policy = partial(transformer_auto_wrap_policy, + transformer_layer_cls={te.TransformerLayer}) + te_model = FullyShardedDataParallel(te_model, + process_group=all_gpus, + use_orig_params=True, + mixed_precision=MixedPrecision( + param_dtype=args.dtype, + reduce_dtype=torch.float32, + ), + sync_module_states=True, + auto_wrap_policy=fsdp_wrap_policy) + + # Print out allocated device memory after the model parameters are sharded + post_mem_use = torch.cuda.memory_allocated(device=f"cuda:{local_rank}") * 1e-6 + if local_rank == 0 or args.verbose: + print(f"[GPU-{local_rank}] Post-FSDP memory use = {post_mem_use}MiB\n", end='') + + # Fp8 setup for TE + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + + # Optimizer must be created after the model is wrapped in FSDP and the parameters are sharded + optim = torch.optim.Adam(te_model.parameters(), lr=0.0001) + + # Start and time dummy "training" iterations + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + for i in range(args.num_iters): + # Generate a random input batch + x = torch.rand(args.seq_length, args.batch_size, + args.num_heads*args.head_dim).to(dtype=args.dtype).cuda() + # fp8_autocast needs to be given the FSDP process group for amax reductions + with te.fp8_autocast(enabled=not args.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus): + y = te_model(x) + loss = y.sum() + # calculate gradient and take training step outside the fp8_autocast context + loss.backward() + optim.step() + del x + if local_rank == 0: + print(f"[GPU-0] Iter. 
{i+1}\n", end='') + end.record() + torch.cuda.synchronize() + + # Print out "training" time and peak memory use stats + train_time = start.elapsed_time(end)/1000. + max_memory_alloc = int(torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") * 1e-6) + if local_rank == 0 or args.verbose: + print(f"[GPU-{local_rank}] Training Time: {train_time}s\n" + + f"[GPU-{local_rank}] Avg. Iter. Time: {train_time /args.num_iters}s\n" + + f"[GPU-{local_rank}] Peak memory use = {max_memory_alloc}MiB\n\n", end='') + + +if __name__ == "__main__": + args = parse_fsdp_args() + train(args) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index edc3da120d..d2ab776288 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -4,12 +4,14 @@ """Internal function used by multiple modules.""" -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, Callable +from dataclasses import dataclass import torch from .. import cpp_extensions as tex from ..fp8 import get_fp8_te_dtype +from ..utils import get_default_init_method def _get_normalization_func(normalization: str, fp8_output: bool, @@ -187,3 +189,18 @@ def _noop_cat( # Perform no-op concat return _NoopCatFunc.apply(split_ranges, full_tensor, *tensors) + + +@dataclass +class _ParameterInitMeta: + """ + Stores essential metadata needed to support deferred parameter initialization. + """ + init_fn: Optional[Callable] = get_default_init_method() + get_rng_state_tracker: Optional[Callable] = None + fp8_meta_index: Optional[int] = None + + def __post_init__(self): + """Safeguard reference to the parameter's parent module and initialization function.""" + if self.init_fn is None: + self.init_fn = get_default_init_method() diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index cf9634b2cc..ad1f383617 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -16,6 +16,7 @@ import torch.nn.functional as F import transformer_engine_extensions as tex +from ._common import _ParameterInitMeta from ..export import is_in_onnx_export_mode from ..fp8 import ( get_default_fp8_recipe, @@ -234,6 +235,8 @@ def __init__(self) -> None: self.fp8_meta["async_amax_reduction"] = bool( int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")) ) + self.param_init_meta = {} + self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() def set_meta_tensor(self, fwd: bool) -> None: """Init scales and amaxes for fwd | bwd.""" @@ -746,6 +749,52 @@ def get_fp8_weights_empty_tensors( ) return fp8_weight_tensors + def register_parameter(self, name, param, **kwargs): + """ + Thin wrapper around PyTorch parameter registration to stash additional parameter + metedata used in deferred initialization. + """ + super().register_parameter(name, param) + self.param_init_meta[name] = _ParameterInitMeta(**kwargs) + + def reset_parameters(self, defer_init: Optional[bool] = False) -> None: + """ + Reset all module parameters to initial values. Unless deferred initialization + is specified, all parameters on a 'meta' device are also materialized on a real cuda + device before the values are reset to initial. 
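+        When primary weights are kept in FP8, the freshly initialized values are also
+        wrapped into `Float8Tensor` parameters.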
+ """ + if defer_init: + return + + for name, param in self.named_parameters(recurse=False): + # Ensure parameter is on a real device + if param.device == torch.device('meta'): + param = param.to(device='cuda') + + # Initialize the parameter values on device + init_fn = self.param_init_meta[name].init_fn + get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker + if get_rng_state_tracker is None: + init_fn(param) + else: + with get_rng_state_tracker().fork(): + init_fn(param) + + # If primary weights are in fp8, wrap the parameter as Float8Tensor + fp8_meta_index = self.param_init_meta[name].fp8_meta_index + if self.primary_weights_in_fp8 and fp8_meta_index is not None: + param = Float8Tensor.to_float8( + param, + fp8_meta=self.fp8_meta, + fp8_meta_index=fp8_meta_index + ) + + # Redo parameter wrap in case we broke it above + # NOTE: Currently this can only be broken when primary weights are in Fp8 but + # re-applying the nn.Parameter() wrap is a no-op when the input is already + # a parameter so we always re-apply it just for extra safety. + setattr(self, name, torch.nn.Parameter(param)) + @abstractmethod def forward(self): """Needs override.""" diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 653e23f4f3..fac941306f 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -4,6 +4,7 @@ """LayerNorm API""" import os +import warnings from typing import Union, Tuple, Optional import torch @@ -139,7 +140,8 @@ def __init__( ) setattr(self.weight, "sequence_parallel", sequence_parallel) setattr(self.bias, "sequence_parallel", sequence_parallel) - self.reset_layer_norm_parameters() + + self.reset_parameters(defer_init=(device == 'meta')) # These many SMs are subtracted from the total SM count when calling forward # and backward LayerNorm C APIs. These envvars can be used to prevent the LN @@ -150,12 +152,25 @@ def __init__( def reset_layer_norm_parameters(self) -> None: """Init LN params""" + warnings.warn( + ("This method will be deprecated in an upcoming release. 
" + "Update your code to use LayerNorm.reset_parameters() instead."), + DeprecationWarning, + stacklevel=2 + ) if not self.zero_centered_gamma: init.ones_(self.weight) else: init.zeros_(self.weight) init.zeros_(self.bias) + def reset_parameters(self, defer_init=False) -> None: + """Init LayerNorm parameters""" + if defer_init: + return + init.constant_(self.weight, float(not self.zero_centered_gamma)) + init.zeros_(self.bias) + @no_torch_dynamo() def forward(self, inp: torch.Tensor) -> torch.Tensor: """LayerNorm FWD""" diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d36d5a9923..2e6803f992 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -25,6 +25,7 @@ from ..utils import ( divide, get_default_init_method, + init_method_constant, cast_if_needed, assert_dim_for_fp8_exec, clear_tensor_data, @@ -33,7 +34,6 @@ set_tensor_model_parallel_attributes, get_distributed_world_size, allreduce, - initialize_affine_weight_gpu, reduce_scatter_along_first_dim, gather_along_first_dim, ) @@ -749,43 +749,25 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel self.eps = eps - self.layer_norm_weight = torch.nn.Parameter( + layer_norm_weight = torch.nn.Parameter( torch.empty(in_features, device=device, dtype=params_dtype) ) - setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) + self.register_parameter('layer_norm_weight', layer_norm_weight, + init_fn=init_method_constant(float(not self.zero_centered_gamma))) + setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition if self.normalization != "RMSNorm": - self.layer_norm_bias = torch.nn.Parameter( + layer_norm_bias = torch.nn.Parameter( torch.empty(in_features, device=device, dtype=params_dtype) ) - setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + self.register_parameter('layer_norm_bias', layer_norm_bias) + setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition else: self.layer_norm_bias = None - self.reset_layer_norm_parameters() - temp_weight = torch.empty( + self.weight_tensor = torch.empty( self.out_features, self.in_features, device=device, dtype=params_dtype) - initialize_affine_weight_gpu( - temp_weight, - init_method, - get_rng_state_tracker, - partition_dim=1 if self.parallel_mode == "row" else 0, - stride=1, - ) - - if self.primary_weights_in_fp8: - self.init_fp8_metadata() - self.fp8_meta["update_amax_and_scale_fwd"] = True - - self.weight_tensor = Float8Tensor.to_float8( - temp_weight, - fp8_meta=self.fp8_meta, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - ) - else: - self.weight_tensor = temp_weight - if self.use_bias: self.bias_tensor = torch.empty( self.out_features, @@ -794,9 +776,6 @@ def __init__( else: self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device) - with torch.no_grad(): - self.bias_tensor.zero_() - # Configure parameter splits self.weight_names = [] self.bias_names = [] @@ -861,7 +840,10 @@ def __init__( if is_subview: weight = weight[split_start:split_end] weight = torch.nn.Parameter(weight) - self.register_parameter(self.weight_names[i], weight) + self.register_parameter(self.weight_names[i], weight, + init_fn=init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT) # Construct 
bias parameter if needed if self.use_bias: @@ -892,8 +874,13 @@ def __init__( del self.weight_tensor del self.bias_tensor - self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) + if self.primary_weights_in_fp8: + self.init_fp8_metadata() + self.fp8_meta["update_amax_and_scale_fwd"] = True + + self.reset_parameters(defer_init=(device == 'meta')) + self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM @@ -911,6 +898,12 @@ def __init__( def reset_layer_norm_parameters(self) -> None: """Init LN params""" + warnings.warn( + ("This method will be deprecated in an upcoming release. " + "Update your code to use LayerNormLinear.reset_parameters() instead."), + DeprecationWarning, + stacklevel=2 + ) if not self.zero_centered_gamma: init.ones_(self.layer_norm_weight) else: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index e5e884cd22..8f88d725ad 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -30,6 +30,7 @@ from ..utils import ( divide, get_default_init_method, + init_method_constant, cast_if_needed, assert_dim_for_fp8_exec, clear_tensor_data, @@ -38,7 +39,6 @@ set_tensor_model_parallel_attributes, get_distributed_world_size, allreduce, - initialize_affine_weight_gpu, reduce_scatter_along_first_dim, gather_along_first_dim, ) @@ -1170,91 +1170,76 @@ def __init__( # LN init self.eps = eps - self.layer_norm_weight = Parameter( + layer_norm_weight = Parameter( torch.empty(hidden_size, device=device, dtype=params_dtype) ) + self.register_parameter('layer_norm_weight', layer_norm_weight, + init_fn=init_method_constant(float(not self.zero_centered_gamma))) setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) if self.normalization != "RMSNorm": - self.layer_norm_bias = Parameter( + layer_norm_bias = Parameter( torch.empty(hidden_size, device=device, dtype=params_dtype) ) - setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + self.register_parameter('layer_norm_bias', layer_norm_bias) + setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition else: self.layer_norm_bias = None - self.reset_layer_norm_parameters() + # FC1 init if self.activation in ['reglu', 'geglu', 'swiglu']: fc1_output_features = 2 * self.size_per_partition else: fc1_output_features = self.size_per_partition - # FC1 init - fc1_temp_weight = torch.empty( - fc1_output_features, hidden_size, device=device, dtype=params_dtype) - - initialize_affine_weight_gpu( - fc1_temp_weight, - init_method, - get_rng_state_tracker, - set_tp_attributes=False, - ) - if self.primary_weights_in_fp8: - self.init_fp8_metadata(num_gemms=2) - self.fp8_meta["update_amax_and_scale_fwd"] = True - - fc1_temp_weight = Float8Tensor.to_float8( - fc1_temp_weight, - fp8_meta=self.fp8_meta, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, + fc1_weight = Parameter( + torch.empty( + fc1_output_features, hidden_size, device=device, dtype=params_dtype ) - - self.fc1_weight = Parameter(fc1_temp_weight) + ) + self.register_parameter('fc1_weight', fc1_weight, + init_fn=init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT) set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1) 
self.fp8_weight_shapes.append(self.fc1_weight.shape) if self.use_bias: - self.fc1_bias = Parameter( + fc1_bias = Parameter( torch.empty(fc1_output_features, device=device, dtype=params_dtype) ) - set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) + self.register_parameter('fc1_bias', fc1_bias) + set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) # pylint: disable=access-member-before-definition else: self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device) - with torch.no_grad(): - self.fc1_bias.zero_() - # FC2 init - fc2_temp_weight = torch.empty( - hidden_size, self.size_per_partition, device=device, dtype=params_dtype) - - initialize_affine_weight_gpu( - fc2_temp_weight, - output_layer_init_method, - get_rng_state_tracker, - set_tp_attributes=False, + fc2_weight = Parameter( + torch.empty(hidden_size, self.size_per_partition, device=device, dtype=params_dtype) ) - - if self.primary_weights_in_fp8: - fc2_temp_weight = Float8Tensor.to_float8( - fc2_temp_weight, - fp8_meta=self.fp8_meta, - fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, - ) - - self.fc2_weight = Parameter(fc2_temp_weight) + self.register_parameter('fc2_weight', fc2_weight, + init_fn=output_layer_init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT) set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1) self.fp8_weight_shapes.append(self.fc2_weight.shape) if self.use_bias: - self.fc2_bias = Parameter( + fc2_bias = Parameter( torch.empty(hidden_size, device=device, dtype=params_dtype) ) + self.register_parameter('fc2_bias', fc2_bias) # RPL if self.set_parallel_mode: - setattr(self.fc2_bias, "sequence_parallel", sequence_parallel) + setattr(self.fc2_bias, "sequence_parallel", sequence_parallel) # pylint: disable=access-member-before-definition else: self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device) + if self.primary_weights_in_fp8: + self.init_fp8_metadata(num_gemms=2) + self.fp8_meta["update_amax_and_scale_fwd"] = True + + self.reset_parameters(defer_init=(device == 'meta')) + # For RPL, bias has to be added after TP collectives # So it cannot be fused with the GEMM if self.set_parallel_mode and self.apply_bias: @@ -1262,9 +1247,6 @@ def __init__( else: self.gemm_bias_unfused_add = False - with torch.no_grad(): - self.fc2_bias.zero_() - if self.bias_gelu_nvfusion: set_jit_fusion_options() if seq_length and micro_batch_size: @@ -1281,6 +1263,12 @@ def __init__( def reset_layer_norm_parameters(self) -> None: """Init LN params""" + warnings.warn( + ("This method will be deprecated in an upcoming release. 
" + "Update your code to use LayerNormMLP.reset_parameters() instead."), + DeprecationWarning, + stacklevel=2 + ) if not self.zero_centered_gamma: init.ones_(self.layer_norm_weight) else: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2a28d67292..2cad516881 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -23,7 +23,6 @@ from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager from ..utils import ( divide, - get_default_init_method, cast_if_needed, assert_dim_for_fp8_exec, clear_tensor_data, @@ -32,7 +31,6 @@ set_tensor_model_parallel_attributes, get_distributed_world_size, allreduce, - initialize_affine_weight_gpu, reduce_scatter_along_first_dim, gather_along_first_dim, ) @@ -82,7 +80,7 @@ def forward( ub_split_ag: bool, ub_atomic_gemm_rs: bool, ub_atomic_gemm_ag: bool, - ub_name: str, + ub_name: str ) -> torch.Tensor: # Make sure input dimensions are compatible in_features = weight.shape[-1] @@ -625,6 +623,10 @@ def __init__( if any([ub_atomic_gemm_rs, ub_atomic_gemm_ag]): assert ub_name is not None, "Userbuffer name [string] is not set." self.ub_name = ub_name + self.get_rng_state_tracker = get_rng_state_tracker + if device == 'meta': + assert parameters_split is None, ("Cannot split module parameters " + "on 'meta' device.") if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs: assert ( @@ -655,44 +657,17 @@ def __init__( elif self.parallel_mode == "row": self.in_features = divide(self.in_features, self.tp_size) - if init_method is None: - init_method = get_default_init_method() - self.sequence_parallel = (self.tp_size > 1) and sequence_parallel - temp_weight = torch.empty( + self.weight_tensor = torch.empty( self.out_features, self.in_features, device=device, dtype=params_dtype) - # TODO(ksivaman): This functionality works with FP8 outside TE. 
- initialize_affine_weight_gpu( - temp_weight, - init_method, - get_rng_state_tracker, - partition_dim=1 if self.parallel_mode == "row" else 0, - stride=1, - ) - - if self.primary_weights_in_fp8: - self.init_fp8_metadata() - self.fp8_meta["update_amax_and_scale_fwd"] = True - - self.weight_tensor = Float8Tensor.to_float8( - temp_weight, - fp8_meta=self.fp8_meta, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - ) - else: - self.weight_tensor = temp_weight - if self.use_bias: self.bias_tensor = torch.empty(self.out_features, device=device, dtype=params_dtype) else: self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device) - with torch.no_grad(): - self.bias_tensor.zero_() - # Configure parameter splits self.weight_names = [] self.bias_names = [] @@ -757,7 +732,10 @@ def __init__( if is_subview: weight = weight[split_start:split_end] weight = torch.nn.Parameter(weight) - self.register_parameter(self.weight_names[i], weight) + self.register_parameter(self.weight_names[i], weight, + init_fn=init_method, + get_rng_state_tracker=get_rng_state_tracker, + fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT) # Construct bias parameter if needed if self.use_bias: @@ -788,6 +766,12 @@ def __init__( del self.weight_tensor del self.bias_tensor + if self.primary_weights_in_fp8: + self.init_fp8_metadata() + self.fp8_meta["update_amax_and_scale_fwd"] = True + + self.reset_parameters(defer_init=(device == 'meta')) + self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) # For RPL, bias has to be added after TP collectives diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index 8da16d1c38..cad357de04 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -4,6 +4,7 @@ """RMSNorm API""" import os +import warnings from typing import Union, Tuple, Optional import torch @@ -141,7 +142,8 @@ def __init__( ) ) setattr(self.weight, "sequence_parallel", sequence_parallel) - self.reset_rms_norm_parameters() + + self.reset_parameters(defer_init=(device == 'meta')) # These many SMs are subtracted from the total SM count when calling forward # and backward RMSNorm C APIs. These envvars can be used to prevent the LN @@ -152,11 +154,22 @@ def __init__( def reset_rms_norm_parameters(self) -> None: """Init RMSNorm params""" + warnings.warn( + ("This method will be deprecated in an upcoming release. 
" + "Update your code to use RMSNorm.reset_parameters() instead."), + DeprecationWarning, + stacklevel=2 + ) if not self.zero_centered_gamma: init.ones_(self.weight) else: init.zeros_(self.weight) + def reset_parameters(self, defer_init=False) -> None: + """Reset RMSNorm parameters""" + if defer_init: + return + init.constant_(self.weight, float(not self.zero_centered_gamma)) @no_torch_dynamo() def forward(self, inp: torch.Tensor) -> torch.Tensor: diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 6250c07d60..819b3d4827 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -40,6 +40,21 @@ def get_default_init_method() -> Callable: return init_method_normal(0.023) +def init_method_constant(val: float) -> Callable: + """Init method to set all tensor elements to a constant value.""" + if val == 1.0: + def init_(tensor: torch.Tensor) -> Callable: + return torch.nn.init.ones_(tensor) + elif val == 0.0: + def init_(tensor: torch.Tensor) -> Callable: + return torch.nn.init.zeros_(tensor) + else: + def init_(tensor: torch.Tensor) -> Callable: + return torch.nn.init.constant_(tensor, val) + + return init_ + + def init_method_normal(sigma: float) -> Callable: """Init method based on N(0, sigma).""" From f6dd3fff261cf8b22d59ed952adf1a77ffcbfa60 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 18 Jan 2024 13:41:24 -0800 Subject: [PATCH 03/15] make TransformerLayer accept a `bshd` or `sbhd` tensor format (#557) * make TransformerLayer accept a `bshd` or `sbhd` tensor format Signed-off-by: Sudhakar Singh * Fixes from feedback Signed-off-by: Sudhakar Singh * more feedback fixes Signed-off-by: Sudhakar Singh * remove incorrect info from docstring Signed-off-by: Sudhakar Singh * fix from feedback Signed-off-by: Sudhakar Singh --------- Signed-off-by: Sudhakar Singh --- tests/pytorch/fused_attn/test_fused_attn.py | 11 ++- tests/pytorch/test_numerics.py | 77 +++++++++++++++++++++ transformer_engine/pytorch/attention.py | 46 ++++++++++-- transformer_engine/pytorch/transformer.py | 12 ++++ 4 files changed, 137 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 3f8962504b..296d9ff214 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -666,10 +666,10 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_te_layer]) @pytest.mark.parametrize("model", ["te_1_2", "te_2_0"]) -def test_te_layer_misc(dtype, model_configs, model): +@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd"]) +def test_te_layer_misc(dtype, model_configs, model, qkv_format): """Test TransformerLayer module with miscellanous settings""" ckpt_attn = True - qkv_format = "bshd" fused_qkv_params = True RoPE = True test_transformer_layer(dtype, model_configs, model, @@ -705,7 +705,7 @@ def _run_transformer_layer( config: ModelConfig, backend: str, ckpt_attn: bool, - qkv_layout: str, + qkv_format: str, workspace_opt: bool, fused_qkv_params: bool, RoPE: bool, @@ -724,6 +724,10 @@ def _run_transformer_layer( # Create input tensor inp = torch.randn(config.max_seqlen_q, config.batch_size, config.hidden_size, dtype=dtype, device="cuda", requires_grad = True) + # In case the format to be tested is batch-first, need to transpose the + # input tensor. 
+ if qkv_format == "bshd": + inp = inp.transpose(0,1) # Create seqlens if "padding" in config.attn_mask_type: @@ -815,6 +819,7 @@ def _run_transformer_layer( qkv_weight_interleaved=False, ub_tp_comm_overlap=False, bias=True, + attn_input_format=qkv_format, ) .to(dtype=dtype, device="cuda") ) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index acc3cbeda3..de7c84695c 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1197,3 +1197,80 @@ def test_gpt_fp8_parameters(dtype, bs, model): outputs = _test_gpt_fp8_parameters(bs, dtype, config, False) outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) assert_all_equal(outputs, outputs_fp8_params) + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", model_configs.keys()) +def test_transformer_layer_hidden_states_format(dtype, bs, model): + config = model_configs[model] + + sigma = 0.023 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + + # Set `torch.manual_seed` to make sure the weights are identical to the + # other layer. Set `*dropout` values to 0 to make sure the forward pass + # is identical to the other layer. + torch.manual_seed(0) + block_sbhd = ( + TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0, + attention_dropout=0, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + hidden_states_format="sbhd" + ) + .to(dtype=dtype) + .cuda() + ) + + # Set `torch.manual_seed` to make sure the weights are identical to the + # other layer. Set `*dropout` values to 0 to make sure the forward pass + # is identical to the other layer. 
+ torch.manual_seed(0) + block_bshd = ( + TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_attention_heads, + layernorm_epsilon=config.eps, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0, + attention_dropout=0, + kv_channels=config.embed, + apply_residual_connection_post_layernorm=False, + output_layernorm=False, + hidden_states_format="bshd" + ) + .to(dtype=dtype) + .cuda() + ) + + for (n1, p1), (n2, p2) in zip(block_bshd.named_parameters(), block_sbhd.named_parameters()): + assert torch.all(torch.eq(p1, p2)), f"{n1}, {n2} not identical" + + x_sbhd = torch.randn( + config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True + ).to(dtype).cuda() + + x_bshd = x_sbhd.transpose(0,1).contiguous() + + # To make sure forward is also identical (just in case some module decides + # to act fancy) + torch.manual_seed(0) + y_sbhd = block_sbhd(x_sbhd) + + # To make sure forward is also identical (just in case some module decides + # to act fancy) + torch.manual_seed(0) + y_bshd = block_bshd(x_bshd) + + assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()]) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 750bc0403c..9316b32864 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1034,11 +1034,34 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: +def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor, tensor_format: str = "sbhd") -> torch.Tensor: """ - input tensor t is of shape [seq_length, ..., dim] - rotary positional embeding tensor `freqs` is of shape [seq_length, ..., dim] + Parameters + ---------- + t: torch.Tensor + input tensor on which rotary positional embedding will be applied + freqs: torch.Tensor + rotary positional embeding tensor `freqs` is of shape + `[seq_length, ..., dim]` + tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' + is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is + of shape `[seq, bs, ...]`. + """ + assert tensor_format in ("sbhd", "bshd"),("Only formats `sbhd` or `bshd` " + "are supported for input tensor " + "`t`.") + max_seq_len = freqs.shape[0] + cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] + + # Only apply the rotary embeddings up to the sequence length of the running + # input. + assert cur_seq_len <= max_seq_len, (f"Rotary Embeddings only supported " + "upto {max_seq_len} sequence length!") + freqs = freqs[:cur_seq_len].to(t.dtype) + if tensor_format == "bshd": + freqs = freqs.transpose(0,1) # [seq, 1, 1, dim] -> [1, seq, 1, dim] + rot_dim = freqs.shape[-1] # ideally t_pass is empty so rotary pos embedding is applied to all tensor t t, t_pass = t[..., :rot_dim], t[..., rot_dim:] @@ -2821,6 +2844,14 @@ class MultiheadAttention(torch.nn.Module): The device on which the parameters of the model will allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + qkv_format: str, default = `sbhd` + dimension format for `query_layer`, `key_layer` and `value_layer`, + {`sbhd`, `bshd`}. `s` stands for the sequence length, `b` batch size, + `h` the number of heads and `d` head size. `sbhd` and `bshd` formats + are used for when sequences in a batch are of equal length or padded to + equal length. 
Please note that these formats do not reflect how + tensors `query_layer`, `key_layer`, `value_layer` are laid out in memory. + For that, please use `_get_qkv_layout` to gain the layout information. Parallelism parameters ---------------------- @@ -2899,9 +2930,11 @@ def __init__( bias: bool = True, normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", + qkv_format: str = "sbhd", ) -> None: super().__init__() + self.qkv_format = qkv_format self.attn_mask_type = attn_mask_type self.window_size = window_size self.window_size = check_set_window_size(attn_mask_type, self.window_size) @@ -3045,6 +3078,7 @@ def __init__( kv_channels, num_gqa_groups=self.num_gqa_groups, attention_dropout=attention_dropout, + qkv_format=self.qkv_format, tp_size=tp_size, get_rng_state_tracker=get_rng_state_tracker, sequence_parallel=sequence_parallel, @@ -3398,14 +3432,14 @@ def forward( # apply relative positional encoding (rotary embedding) if rotary_pos_emb is not None: q_pos_emb, k_pos_emb = rotary_pos_emb - query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb) + query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format) + key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format) context_layer = self.core_attention( query_layer, key_layer, value_layer, - qkv_format='sbhd', + qkv_format=self.qkv_format, cu_seqlens_q=None, cu_seqlens_kv=None, attention_mask=attention_mask, diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index f1c6194d29..addaf31689 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -168,6 +168,14 @@ class TransformerLayer(torch.nn.Module): The device on which the parameters of the model will allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the forward pass. + attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd' + This controls whether the dimensions of the + intermediate hidden states is 'batch first' ('bshd') or + 'sequence first' ('sbhd'). `s` stands for the sequence + length, `b` batch size, `h` the number of heads, `d` + head size. Note that these formats are very closely + related to the `qkv_format` in the `MultiHeadAttention` + and `DotProductAttention` modules. 
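+                      The `hidden_states` input to the forward pass is expected to be
+                      in this same format.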
Parallelism parameters ---------------------- @@ -253,6 +261,7 @@ def __init__( activation: str = 'gelu', normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", + attn_input_format: str = "sbhd", ) -> None: super().__init__() @@ -331,6 +340,8 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker + self.attn_input_format = attn_input_format + attention_args = ( hidden_size, num_attention_heads, @@ -360,6 +371,7 @@ def __init__( "ub_split_rs" : ub_split_rs, "ub_atomic_gemm_rs" : ub_atomic_gemm_rs, "ub_atomic_gemm_ag" : ub_atomic_gemm_ag, + "qkv_format" : self.attn_input_format, } self.self_attention = MultiheadAttention( From b25611bd4ad36706552cdfb7c4798879e5eb0a5b Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 19 Jan 2024 23:29:30 -0800 Subject: [PATCH 04/15] Fix failing CI due to PR #557 merge (#616) fix failing tests due to PR #557 Signed-off-by: Sudhakar Singh Co-authored-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 4 ++-- transformer_engine/pytorch/attention.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index de7c84695c..215cae2b97 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1225,7 +1225,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): kv_channels=config.embed, apply_residual_connection_post_layernorm=False, output_layernorm=False, - hidden_states_format="sbhd" + attn_input_format="sbhd" ) .to(dtype=dtype) .cuda() @@ -1248,7 +1248,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): kv_channels=config.embed, apply_residual_connection_post_layernorm=False, output_layernorm=False, - hidden_states_format="bshd" + attn_input_format="bshd" ) .to(dtype=dtype) .cuda() diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 9316b32864..cf7bee8c66 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1034,7 +1034,11 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor, tensor_format: str = "sbhd") -> torch.Tensor: +def apply_rotary_pos_emb( + t: torch.Tensor, + freqs: torch.Tensor, + tensor_format: str = "sbhd" + ) -> torch.Tensor: """ Parameters ---------- @@ -1056,8 +1060,10 @@ def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor, tensor_format: st # Only apply the rotary embeddings up to the sequence length of the running # input. 
- assert cur_seq_len <= max_seq_len, (f"Rotary Embeddings only supported " - "upto {max_seq_len} sequence length!") + if cur_seq_len > max_seq_len: + raise Exception(f"Rotary Embeddings only supported upto {max_seq_len} " + "sequence length!") + freqs = freqs[:cur_seq_len].to(t.dtype) if tensor_format == "bshd": freqs = freqs.transpose(0,1) # [seq, 1, 1, dim] -> [1, seq, 1, dim] From c6f0a1f555ab315493032b0a77b0985654d42964 Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Sun, 21 Jan 2024 01:47:13 -0800 Subject: [PATCH 05/15] Activation offloading to CPU's for the Linear, Layernorm Linear and the Layernorm MLP modules (#571) * Added support activation offloading to CPU's Signed-off-by: Selvaraj Anandaraj * Moving CPU offloading library to TE Signed-off-by: Selvaraj Anandaraj * Restructured code, added switch to choose between weight/activation offloading Signed-off-by: Selvaraj Anandaraj * Removed arg during constructor Signed-off-by: Selvaraj Anandaraj * Fix nit-pick errors Signed-off-by: Selvaraj Anandaraj * Documentation fixes Signed-off-by: Przemek Tredak * Fix to the code block in docs Signed-off-by: Przemek Tredak * Added offloading unit test Signed-off-by: Selvaraj Anandaraj * Fixed formatting Signed-off-by: Selvaraj Anandaraj * wgrad fusion fix, minor errors and lint Signed-off-by: Kirthi Shankar Sivamani * Errors, test, lint Signed-off-by: Kirthi Shankar Sivamani * RM test file Signed-off-by: Kirthi Shankar Sivamani * Fixed stray PyT tensors in LayernormMLP getting offloaded Signed-off-by: Selvaraj Anandaraj * Fixed typi Signed-off-by: Selvaraj Anandaraj * Fix offloading for rmsnorm, rm test Signed-off-by: Kirthi Shankar Sivamani * Fix errors Signed-off-by: Kirthi Shankar Sivamani * Float8Tensor compatible offloading Signed-off-by: Kirthi Shankar Sivamani * Cleanup Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Przemek Tredak Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Selvaraj Anandaraj Co-authored-by: Przemyslaw Tredak Co-authored-by: Kirthi Shankar Sivamani --- docs/api/pytorch.rst | 2 + tests/pytorch/test_sanity.py | 24 +- transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/cpu_offload.py | 506 ++++++++++++++++++ .../pytorch/module/layernorm_linear.py | 27 +- .../pytorch/module/layernorm_mlp.py | 38 +- transformer_engine/pytorch/module/linear.py | 26 +- 7 files changed, 615 insertions(+), 9 deletions(-) create mode 100644 transformer_engine/pytorch/cpu_offload.py diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 7c81c2f071..9b291e6d0a 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -40,3 +40,5 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.checkpoint .. autoapifunction:: transformer_engine.pytorch.onnx_export + +.. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index f1e172b36b..593231d6d1 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Optional +from contextlib import nullcontext import torch import pytest @@ -20,6 +21,7 @@ TransformerLayer, RMSNorm, LayerNorm, + get_cpu_offload_context, ) from transformer_engine.common import recipe @@ -215,7 +217,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated." 
-def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad): +def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): te_inp_hidden_states = torch.randn( config.seq_len, config.batch_size, config.hidden_size, dtype=dtype, requires_grad=True ).cuda() @@ -223,9 +225,16 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad): if skip_wgrad: _disable_wgrads(block) + if cpu_offload: + offload_context, sync_function = get_cpu_offload_context(enabled=True) + else: + offload_context = nullcontext() + sync_function = lambda x: x + use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context: te_out = block(te_inp_hidden_states) + te_out = sync_function(te_out) loss = te_out.sum() loss.backward() torch.cuda.synchronize() @@ -449,9 +458,11 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad, @pytest.mark.parametrize("activation", all_activations) @pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) +@pytest.mark.parametrize("cpu_offload", all_boolean) def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, bias, activation, - normalization, parallel_attention_mlp): + normalization, parallel_attention_mlp, + cpu_offload): config = model_configs[model] if fp8_recipe is not None: @@ -489,7 +500,7 @@ def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad, .cuda() ) - _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad) + _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload) def test_sanity_gpt_126m(): @@ -512,6 +523,7 @@ def test_sanity_gpt_126m(): activation="gelu", normalization="LayerNorm", parallel_attention_mlp=False, + cpu_offload=False, ) @@ -713,7 +725,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): .cuda() ) - _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad) + _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) @pytest.mark.parametrize("dtype", param_types) @@ -751,7 +763,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): .cuda() ) - _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad) + _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) @pytest.mark.parametrize("dtype", param_types) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 43ad38e108..16bd128734 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -17,6 +17,7 @@ from .export import onnx_export from .distributed import checkpoint from .distributed import CudaRNGStatesTracker +from .cpu_offload import get_cpu_offload_context # Register custom op symbolic ONNX functions from .te_onnx_extensions import ( onnx_cast_to_fp8, diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py new file mode 100644 index 0000000000..dcede62ef7 --- /dev/null +++ b/transformer_engine/pytorch/cpu_offload.py @@ -0,0 +1,506 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Functionality for CPU offloading of tensors saved for backward pass.""" +from typing import Any +from contextlib import nullcontext +import torch + +from .float8_tensor import Float8Tensor + +__all__ = ['get_cpu_offload_context'] + +CPUOffloadEnabled = False + + +class CpuOffloadSavedTensorHook: + """Contex-manager that executes a pair of pack/unpack hooks for saved tensors. + + In this context, the ``on_save_for_backward`` method will be called every time + a tensor is saved for backward (this includes intermediary results saved using + :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but + also those recorded by a PyTorch-defined operation). + + The ``on_get_saved_tensors`` method will be called when the backward function + of this op attempts to retrieve the saved tensor from context (this includes + :func: `torch.Tensor.backward()` or :func: `torch.autograd.grad()`. It takes the + as input the return value of the ``on_save_for_backward``, and is meant to return + an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of + size, device and element values. + + Example: + + >>> import torch + >>> from typing import Any + >>> + >>> class DummyHook(CpuOffloadSavedTensorHook): + ... + ... def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + ... logging.info("On save", tensor) + ... return (tensor,) + ... + ... def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + ... logging.info("On get", saved_state) + ... tensor, = saved_state + ... return tensor + ... + >>> a = torch.ones(5, requires_grad=True) + >>> b = torch.ones(5, requires_grad=True) * 2 + >>> with DummyHook(): + ... y = a * b + ... + On save tensor([1., 1., 1., 1., 1.], requires_grad=True) + On save tensor([2., 2., 2., 2., 2.], grad_fn=) + >>> y.sum().backward() + On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),) + On get (tensor([2., 2., 2., 2., 2.], grad_fn=),) + + """ + + def __init__(self) -> None: + self.inside_context = False + + def __enter__(self): + global CPUOffloadEnabled + CPUOffloadEnabled = True + + self.inside_context = True + torch._C._autograd._push_saved_tensors_default_hooks( + self.on_save_for_backward, + self.on_get_saved_tensor + ) + + def __exit__(self, *args: Any): + global CPUOffloadEnabled + CPUOffloadEnabled = False + + self.inside_context = False + torch._C._autograd._pop_saved_tensors_default_hooks() + + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + """On save for backward.""" + raise NotImplementedError("`on_save_for_backward: Callable[[torch.Tensor], Any]`" + "is not implemented in CpuOffloadHook class. Inherit " + "this class and implement your custom hooks") + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + """On get saved tensor.""" + raise NotImplementedError("`on_get_saved_tensors: Callable[[Any], torch.Tensor]`" + "is not implemented in CpuOffloadHook class. Inherit " + "this class and implement your custom hooks") + + +class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook): + """Context-manager that offloads/recovers tensors through an offload hander. + + The hook just offloads/recovers the tensor object to the handler through `tensor_push` + and `tensor_pop` interface. How the offload-handler manages the offloading, recovering + or prefetching timing is transparent to this hook. 
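+
+    Example (illustrative sketch; any `OffloadHandler` implementation defined below,
+    such as `SynchronizedGroupOffloadHandler`, can be plugged in):
+
+    >>> a = torch.ones(5, requires_grad=True, device="cuda")
+    >>> handler = SynchronizedGroupOffloadHandler(num_offload_group=1)
+    >>> with CpuOffloadHookWithOffloadHandler(handler):
+    ...     y = a * a    # the tensor saved for backward is pushed to the handler
+    >>> y.sum().backward()   # the saved tensor is popped back from the handler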
+ """ + def __init__(self, offload_handler, handler_extra_kwargs={}, debug=False) -> None: # pylint: disable=dangerous-default-value + self.debug = debug + self.offload_handler = offload_handler + self.handler_extra_kwargs = handler_extra_kwargs + super().__init__() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + retrieve_identifier = self.offload_handler.tensor_push( + tensor, + **self.handler_extra_kwargs + ) + return retrieve_identifier + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + tensor = self.offload_handler.tensor_pop( + saved_state, + **self.handler_extra_kwargs + ) + return tensor + + +class OffloadHandler: + """A base class for CPU offload-handler.""" + def __init__(self) -> None: + pass + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + """Tensor push.""" + raise NotImplementedError("`tensor_push is not implented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_push.") + + def tensor_pop(self, tensor_tag: Any, **kwargs): + """Tensor pop.""" + raise NotImplementedError("`tensor_pop is not implented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_pop.") + + +class GroupCommitFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. + """ + @staticmethod + def forward(ctx, tensor, cpu_offload_handler): + cpu_offload_handler.on_group_commit_forward() + ctx.cpu_offload_handler = cpu_offload_handler + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_commit_backward() + return grad_output, None + + +group_prefetch_offload_commit = GroupCommitFunction.apply + + +class SynchronizedGroupOffloadHandler(OffloadHandler): + """Offload Handler that offloads/reloads in a synchronized way. + The device-to-host and host-to-device copying happen in the same stream + as the computation kernels, thus the copying will block computation. + """ + def __init__(self, + num_offload_group, + tensor_need_offloading_checker=(lambda _: True), + debug=False + ) -> None: + super().__init__() + + self.num_offload_group = num_offload_group + self.tensor_need_offloading_checker = tensor_need_offloading_checker + self.debug = debug + + self.groupid_reset() + + def groupid_reset(self): + """Groupid reset.""" + # Data structures to label saved tensors and book-keep their cpu copies. + # Currently, on push, create a new cpu tensor and copies; on pop, copies + # the tensor back to gpu and deletes the cpu tensor. 
+ # These will increment whenever `group_commit()` is invoked + self.current_group, self.tensor_count_current_group = (0, 0) + self.tensor_tag_to_state = {} + + def on_group_commit_forward(self): + """On group commit forward.""" + # finishing up with updating current group and tensor count + self.current_group += 1 # increment + self.tensor_count_current_group = 0 # reset + + def on_group_commit_backward(self): + """On group commit backward.""" + self.current_group -= 1 + assert self.current_group >= 0 + + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + fp8_offload = isinstance(src_tensor, Float8Tensor) + + cpu_backup = torch.empty( + src_tensor.size(), dtype=torch.uint8 if fp8_offload else src_tensor.dtype, + layout=src_tensor.layout, device="cpu", pin_memory=pin_memory) + + if fp8_offload: + cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup) + + cpu_backup.copy_(src_tensor, non_blocking=pin_memory) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None): + """Reload.""" + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + return cpu_backup.to(dev, non_blocking=non_blocking) + + def tensor_push(self, tensor: torch.Tensor, **kwargs): + """Tensor push.""" + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + if (self.current_group < self.num_offload_group + and self.tensor_need_offloading_checker(tensor)): + state = SynchronizedGroupOffloadHandler.offload(tensor) + self.tensor_tag_to_state[tensor_tag] = state + else: + # will be offloaded together after group commit + self.tensor_tag_to_state[tensor_tag] = tensor + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + state = self.tensor_tag_to_state.pop(tensor_tag) + if isinstance(state, tuple): + tensor = SynchronizedGroupOffloadHandler.reload(state) + else: + tensor = state + return tensor + + +class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): + """Compared to synchronize, this uses more memory because of the buffer but + achieves better performance due to the overlapping. D2h and h2d copying are + completely hidden behind computation if computation time of a layer is longer + than host-device communication time. Bulk offloading with delay and bulk reloading + with prefetch are implemented. 
""" + def __init__(self, + num_offload_group, # must be <= actual number of groups (number of commits) + num_prefetch_group=1, + tensor_need_offloading_checker=(lambda t: True), + debug=False + ) -> None: + super().__init__(num_offload_group=num_offload_group, + tensor_need_offloading_checker=tensor_need_offloading_checker, + debug=debug) + self.num_prefetch_group = num_prefetch_group + + # prepare for tensor buffer + self.tensor_id_to_tensor_buf_double_bufs = [] + for _ in range(2): + self.tensor_id_to_tensor_buf_double_bufs.append({}) + + # allocate streams and events for synchronization + self.d2h_stream = torch.cuda.Stream() + self.h2d_stream = torch.cuda.Stream() + self.h2d_finish_events = [] + self.compute_stream_bwd_start_events = [] + for _ in range(self.num_offload_group): + self.h2d_finish_events.append(torch.cuda.Event()) + self.compute_stream_bwd_start_events.append(torch.cuda.Event()) + self.d2h_final_event = torch.cuda.Event() + + def get_tensor_buf_for_offloaded_tensor(self, tensor, tensor_tag): + """Get tensor buffer for offloaded tensor.""" + group_id, tensor_id = tensor_tag + # obtain ping-pong buffer + id_buf_map = self.tensor_id_to_tensor_buf_double_bufs[(group_id % 2)] + + if not tensor_id in id_buf_map: + allocate_new_buf = True + else: + tensor_buf = id_buf_map[tensor_id] + if not (tensor_buf.size() == tensor.size() and tensor_buf.dtype == tensor.dtype): # pylint: disable=simplifiable-if-statement + allocate_new_buf = True + else: + allocate_new_buf = False # in this case, reuse the old buffer + + if allocate_new_buf: + # supposed to only execute once + fp8_offload = isinstance(tensor, Float8Tensor) + buffer = torch.empty( + tensor.size(), dtype=torch.uint8 if fp8_offload else tensor.dtype, + layout=tensor.layout, device=tensor.device) + + if isinstance(tensor, Float8Tensor): + id_buf_map[tensor_id] = Float8Tensor.make_like(tensor, data=buffer) + else: + id_buf_map[tensor_id] = buffer + + return id_buf_map[tensor_id] + + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + + if (self.current_group < self.num_offload_group + and self.tensor_need_offloading_checker(tensor)): + # first copy the tensor to tensorbuf, so that the original tensor will not be deleted + tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag) + tensor_buf.copy_(tensor) + if hasattr(tensor,"weight_offloading"): + tensor_buf.weight_offloading = True + if hasattr(tensor,"activation_offloading"): + tensor_buf.activation_offloading = True + # Here we just save it, and at commit, bulk_offload_group will handle it + self.tensor_tag_to_state[tensor_tag] = tensor_buf + else: + self.tensor_tag_to_state[tensor_tag] = tensor + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + tensor = self.tensor_tag_to_state.pop(tensor_tag) + # the tensor should have been copied back in on_group_commit_backward() + # which invokes bulk_reload_group. 
+ assert not isinstance(tensor, tuple) + return tensor + + def bulk_offload_group(self, group_to_offload): + """Bulk offload group.""" + with torch.cuda.stream(self.d2h_stream): + for tensor_tag, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == group_to_offload: + assert not isinstance(state, tuple) + tensor_on_device = state + + # if offload, return the reference to cpu copy + if self.tensor_need_offloading_checker(tensor_on_device): + state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) + self.tensor_tag_to_state[tensor_tag] = state + + def synchronize_on_group_commit_forward(self, current_group): + """Synchronize on group commit forward.""" + # the host should wait for the copying of previous group + # to avoid overwriting buffer + previous_group = current_group - 1 + if previous_group < self.num_offload_group: + torch.cuda.synchronize() + # TODO (guyueh): this part is originally designed to reduce the peak memory usage. # pylint: disable=fixme + # however, uncommenting this part will cause illegal access, have not figured out why. + + if previous_group + 2 >= self.num_offload_group: + # this buffer is no longer required + self.tensor_id_to_tensor_buf_double_bufs[(previous_group % 2)] = {} + + # the copying of this group should wait for the computation stream event + if current_group < self.num_offload_group: + # perform bulk offloading + self.bulk_offload_group(current_group) + if current_group == self.num_offload_group - 1: + self.d2h_stream.record_event(self.d2h_final_event) + + def on_group_commit_forward(self): + """This function will cause host device synchronization""" + # handle synchronization events + self.synchronize_on_group_commit_forward(self.current_group) + + # during forward, the next_group_to_fetch always points to the min of + # the last commited group, and the last offloaded group + self.next_group_to_fetch = min(self.current_group, self.num_offload_group -1) + + super().on_group_commit_forward() + + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + assert group_to_reload < self.num_offload_group + if group_to_reload == self.num_offload_group - 1: + self.h2d_stream.wait_event(self.d2h_final_event) + with torch.cuda.stream(self.h2d_stream): + # move back tensors + for tensor_label, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_to_reload: + if isinstance(state, tuple): + recovered_tensor = SynchronizedGroupOffloadHandler.reload(state) + self.tensor_tag_to_state[tensor_label] = recovered_tensor + + def on_group_commit_backward(self): + # first decrement the current group. + # after last commit in forward, the group will +1; in backward it -1. + # Finally it should be decremented to 0. 
+ self.current_group -= 1 + assert self.current_group >= 0 + + # decide the range of group to prefetch + should_prefetch_until_group = self.current_group - self.num_prefetch_group + should_prefetch_until_group = max(should_prefetch_until_group, 0) + + # do prefetch + for group_num_to_prefetch in range( + self.next_group_to_fetch, should_prefetch_until_group - 1, -1 + ): + # record the event in the compute stream, for h2d to wait + torch.cuda.current_stream().record_event( + self.compute_stream_bwd_start_events[group_num_to_prefetch]) + + # start of h2d should wait for the compute and the d2h + self.h2d_stream.wait_event(self.compute_stream_bwd_start_events[group_num_to_prefetch]) + + #recover tensors (copy back from host) + self.bulk_reload_group(group_num_to_prefetch) + + # record an event for the backward of this layer to wait + self.h2d_stream.record_event(self.h2d_finish_events[group_num_to_prefetch]) + + # always is set to -1 at the end of the backward + self.next_group_to_fetch = min(self.num_offload_group - 1, should_prefetch_until_group - 1) + + # wait for the current group + if self.current_group < self.num_offload_group: + torch.cuda.current_stream().wait_event(self.h2d_finish_events[self.current_group]) + + +def get_cpu_offload_context( + enabled: bool = False, + num_layers: int = 1, + offload_activations: bool = True, + offload_weights: bool = True): + """ + This function returns the CPU Offload context and the synchronizer function that needs to be + used after every transformer layer. Returns `nullcontext()` if offloading is not enabled. + + Usage: + + .. code-block:: python + + cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True) + + with cpu_offload_context: + te_layer.forward(inp_tensor) + cpu_offload_synchronizer() + + Parameters + ---------- + enabled: bool, default = `False` + When set to True, CPU Offloading functionality is enabled. + num_layers: int, default = 1 + Determines the number of transformer layers + you want to offload activations/weights for. + offload_activations: bool, default = `True` + When set to `True`, offloads the activations for the TE layer. + offload_weights: bool, default = `True` + When set to `True`, offloads the weights for the TE layer. 
+ + """ + + def tensor_need_offloading_checker_activations(tensor): + return hasattr(tensor,"activation_offloading") + + # This includes the Gradient Accumulation Buffer + def tensor_need_offloading_checker_weights(tensor): + return hasattr(tensor, "weight_offloading") + + def tensor_need_offloading_checker_all(tensor): # pylint: disable=unused-argument + return (hasattr(tensor,"activation_offloading") or hasattr(tensor, "weight_offloading")) + + if offload_activations and offload_weights: + tensor_need_offloading_checker = tensor_need_offloading_checker_all + elif offload_activations: + tensor_need_offloading_checker = tensor_need_offloading_checker_activations + elif offload_weights: + tensor_need_offloading_checker = tensor_need_offloading_checker_weights + else: + raise ValueError( + "CPU Offloading is enabled while it is not " + "mentioned what to offload (weights/activations)") + + cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( + num_offload_group=num_layers, + num_prefetch_group=1, + tensor_need_offloading_checker=tensor_need_offloading_checker + ) + + def group_prefetch_offload_commit_async(tensor): + return group_prefetch_offload_commit(tensor,cpu_offload_handler) + + if enabled: + return ( + CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), + group_prefetch_offload_commit_async, + ) + return nullcontext(), group_prefetch_offload_commit_async diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 2e6803f992..0431b8e046 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -42,7 +42,6 @@ from ._common import _apply_normalization, _noop_cat from ..float8_tensor import Float8Tensor - __all__ = ["LayerNormLinear"] @@ -68,6 +67,7 @@ def forward( fp8_calibration: bool, fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, + cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, sequence_parallel: bool, @@ -239,12 +239,27 @@ def forward( ) if is_grad_enabled: + if cpu_offloading: + if fuse_wgrad_accumulation: + weight.main_grad.weight_offloading = True + if fp8: + weight_t_fp8.weight_offloading = True + ln_weight.weight_offloading = True + weight.weight_offloading = True + + inputmat.activation_offloading = True + if normalization == "LayerNorm": + mu.activation_offloading = True + rsigma.activation_offloading = True + ln_out.activation_offloading = True + ctx.save_for_backward( inputmat, ln_weight, mu, rsigma, weight, + weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, weight_t_fp8, ln_out, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, @@ -254,6 +269,7 @@ def forward( ctx.fp8 = fp8 ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel @@ -298,11 +314,16 @@ def backward( mu, rsigma, weight, + main_grad, weight_t_fp8, ln_out, fwd_scale_inverses, ) = ctx.saved_tensors + if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: + weight = torch.nn.Parameter(weight, False) + weight.main_grad = main_grad + # Primary weights are in FP8. 
if ctx.fp8 and weight_t_fp8 is None: weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch) @@ -582,6 +603,7 @@ def backward( None, None, None, + None, ) @@ -992,6 +1014,8 @@ def forward( is_first_microbatch ) + from ..cpu_offload import CPUOffloadEnabled + if torch.is_grad_enabled(): fwd_fn = _LayerNormLinear.apply args = [] @@ -1013,6 +1037,7 @@ def forward( self.fp8_calibration, self.fp8_meta, self.fuse_wgrad_accumulation, + CPUOffloadEnabled, self.tp_group, self.tp_size, self.sequence_parallel, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 8f88d725ad..050ac21a92 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -51,7 +51,6 @@ from ..float8_tensor import Float8Tensor from ._common import _apply_normalization - __all__ = ["LayerNormMLP"] @@ -95,6 +94,7 @@ def forward( fp8_calibration: bool, fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, + cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, sequence_parallel: bool, @@ -420,6 +420,26 @@ def forward( clear_tensor_data(gelu_out) if is_grad_enabled: + if cpu_offloading: + if fuse_wgrad_accumulation: + fc1_weight.main_grad.weight_offloading = True + fc2_weight.main_grad.weight_offloading = True + if fp8: + fc1_weight_t_fp8.weight_offloading = True + fc2_weight_t_fp8.weight_offloading = True + ln_weight.weight_offloading = True + fc1_weight.weight_offloading = True + fc2_weight.weight_offloading = True + fc1_bias.weight_offloading = True + + inputmat.activation_offloading = True + if normalization == "LayerNorm": + mu.activation_offloading = True + rsigma.activation_offloading = True + ln_out.activation_offloading = True + fc1_out.activation_offloading = True + gelu_out.activation_offloading = True + ctx.save_for_backward( inputmat, ln_weight, @@ -429,8 +449,10 @@ def forward( fc1_out, gelu_out, fc1_weight, + fc1_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None, fc1_weight_t_fp8, fc2_weight, + fc2_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None, fc2_weight_t_fp8, fc1_bias, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, @@ -440,6 +462,7 @@ def forward( ctx.fp8 = fp8 ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_fc1_bias = use_fc1_bias ctx.use_fc2_bias = use_fc2_bias @@ -492,13 +515,22 @@ def backward( fc1_out, gelu_out, fc1_weight, + fc1_weight_main_grad, fc1_weight_t_fp8, fc2_weight, + fc2_weight_main_grad, fc2_weight_t_fp8, fc1_bias, fwd_scale_inverses, ) = ctx.saved_tensors + if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: + fc1_weight = Parameter(fc1_weight, False) + fc2_weight = Parameter(fc2_weight, False) + + fc1_weight.main_grad = fc1_weight_main_grad + fc2_weight.main_grad = fc2_weight_main_grad + # Primary weights are in FP8. 
if ctx.fp8 and fc1_weight_t_fp8 is None: fc1_weight_t_fp8 = fc1_weight.transpose(update_cache=ctx.is_first_microbatch) @@ -993,6 +1025,7 @@ def backward( None, None, None, + None, ) @@ -1336,6 +1369,8 @@ def forward( is_first_microbatch ) + from ..cpu_offload import CPUOffloadEnabled + if torch.is_grad_enabled(): fwd_fn = _LayerNormMLP.apply args = [] @@ -1362,6 +1397,7 @@ def forward( self.fp8_calibration, self.fp8_meta, self.fuse_wgrad_accumulation, + CPUOffloadEnabled, self.tp_group, self.tp_size, self.sequence_parallel, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2cad516881..87c78aa151 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -45,7 +45,6 @@ from ..float8_tensor import Float8Tensor - __all__ = ["Linear"] @@ -68,6 +67,7 @@ def forward( fp8_calibration: bool, fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, + cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, sequence_parallel: bool, @@ -266,12 +266,26 @@ def forward( saved_inputmat = inputmat else: saved_inputmat_t = inputmat_t + if cpu_offloading: + saved_inputmat_t.activation_offloading = True else: saved_inputmat = inputmat_no_fp8 + + if cpu_offloading: + if fuse_wgrad_accumulation: + weight.main_grad.weight_offloading = True + if fp8: + weight_t_fp8.weight_offloading = True + weight.weight_offloading = True + + if saved_inputmat is not None: + saved_inputmat.activation_offloading = True + ctx.save_for_backward( saved_inputmat, saved_inputmat_t, weight, + weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, weight_t_fp8 if fp8 else None, fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, ) @@ -279,6 +293,7 @@ def forward( ctx.fp8 = fp8 ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel @@ -315,10 +330,15 @@ def backward( inputmat, inputmat_t, weight, + main_grad, weight_t_fp8, fwd_scale_inverses, ) = ctx.saved_tensors + if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: + weight = torch.nn.Parameter(weight, False) + weight.main_grad = main_grad + # Primary weights are in FP8. 
if ctx.fp8 and weight_t_fp8 is None: weight_t_fp8 = weight.transpose(update_cache=ctx.is_first_microbatch) @@ -515,6 +535,7 @@ def backward( None, None, None, + None, ) @@ -862,6 +883,8 @@ def forward( is_first_microbatch ) + from ..cpu_offload import CPUOffloadEnabled + if torch.is_grad_enabled(): linear_fn = _Linear.apply args = [] @@ -880,6 +903,7 @@ def forward( self.fp8_calibration, self.fp8_meta, self.fuse_wgrad_accumulation, + CPUOffloadEnabled, self.tp_group, self.tp_size, self.sequence_parallel, From cc289dc55df47189ec3bb6ec3b7332d76004951f Mon Sep 17 00:00:00 2001 From: Marks101 <46690260+Marks101@users.noreply.github.com> Date: Mon, 22 Jan 2024 19:05:24 +0100 Subject: [PATCH 06/15] [PyTorch] Fix bias initialization introduced in #596 (#622) Signed-off-by: Markus Schnoes --- transformer_engine/pytorch/module/layernorm_linear.py | 6 ++++-- transformer_engine/pytorch/module/layernorm_mlp.py | 9 ++++++--- transformer_engine/pytorch/module/linear.py | 4 +++- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 0431b8e046..589c787b74 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -781,7 +781,8 @@ def __init__( layer_norm_bias = torch.nn.Parameter( torch.empty(in_features, device=device, dtype=params_dtype) ) - self.register_parameter('layer_norm_bias', layer_norm_bias) + self.register_parameter('layer_norm_bias', layer_norm_bias, + init_fn=init_method_constant(0.0)) setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition else: self.layer_norm_bias = None @@ -873,7 +874,8 @@ def __init__( if is_subview: bias = bias[split_start:split_end] bias = torch.nn.Parameter(bias) - self.register_parameter(self.bias_names[i], bias) + self.register_parameter(self.bias_names[i], bias, + init_fn=init_method_constant(0.0)) if parallel_mode == "row": bias.sequence_parallel = sequence_parallel else: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 050ac21a92..54de8f16f8 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1213,7 +1213,8 @@ def __init__( layer_norm_bias = Parameter( torch.empty(hidden_size, device=device, dtype=params_dtype) ) - self.register_parameter('layer_norm_bias', layer_norm_bias) + self.register_parameter('layer_norm_bias', layer_norm_bias, + init_fn=init_method_constant(0.0)) setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition else: self.layer_norm_bias = None @@ -1240,7 +1241,8 @@ def __init__( fc1_bias = Parameter( torch.empty(fc1_output_features, device=device, dtype=params_dtype) ) - self.register_parameter('fc1_bias', fc1_bias) + self.register_parameter('fc1_bias', fc1_bias, + init_fn=init_method_constant(0.0)) set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) # pylint: disable=access-member-before-definition else: self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device) @@ -1260,7 +1262,8 @@ def __init__( fc2_bias = Parameter( torch.empty(hidden_size, device=device, dtype=params_dtype) ) - self.register_parameter('fc2_bias', fc2_bias) + self.register_parameter('fc2_bias', fc2_bias, + init_fn=init_method_constant(0.0)) # RPL if self.set_parallel_mode: 
setattr(self.fc2_bias, "sequence_parallel", sequence_parallel) # pylint: disable=access-member-before-definition diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 87c78aa151..88eb6080e8 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -26,6 +26,7 @@ cast_if_needed, assert_dim_for_fp8_exec, clear_tensor_data, + init_method_constant, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -764,7 +765,8 @@ def __init__( if is_subview: bias = bias[split_start:split_end] bias = torch.nn.Parameter(bias) - self.register_parameter(self.bias_names[i], bias) + self.register_parameter(self.bias_names[i], bias, + init_fn=init_method_constant(0.0)) if parallel_mode == "row": bias.sequence_parallel = sequence_parallel else: From bbadf40304e20f0640885b64e8fd0fbeedc8a6ad Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Tue, 23 Jan 2024 15:30:47 -0600 Subject: [PATCH 07/15] [PyTorch] Fix for deferred init bug causing NeMo MLPerf LLM crash (#619) * added missing parameter materialization on real device for LayerNorm and RMSNorm Signed-off-by: Alp Dener * added new unittest for deferred initialization and modified parameter materialization to support standalone execution outside of FSDP Signed-off-by: Alp Dener * restored tensor parallel attributes that were being wiped out by the parameter reset Signed-off-by: Alp Dener * fixed incorrect order of fp8 metadata initialization Signed-off-by: Alp Dener * added deferred init unittest to the QA script Signed-off-by: Alp Dener --------- Signed-off-by: Alp Dener --- qa/L0_pytorch_unittest/test.sh | 1 + tests/pytorch/test_deferred_init.py | 87 +++++++++++++++++++ transformer_engine/pytorch/module/base.py | 2 +- .../pytorch/module/layernorm.py | 11 ++- .../pytorch/module/layernorm_linear.py | 41 ++++++--- .../pytorch/module/layernorm_mlp.py | 25 ++++-- transformer_engine/pytorch/module/linear.py | 33 ++++--- transformer_engine/pytorch/module/rmsnorm.py | 6 +- 8 files changed, 168 insertions(+), 38 deletions(-) create mode 100644 tests/pytorch/test_deferred_init.py diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 729b4b8992..51b7b6235e 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -8,6 +8,7 @@ set -e pip install pytest==6.2.5 onnxruntime==1.13.1 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py +pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py pytest -v -s $TE_PATH/tests/pytorch/test_jit.py pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py diff --git a/tests/pytorch/test_deferred_init.py b/tests/pytorch/test_deferred_init.py new file mode 100644 index 0000000000..cbc761a27c --- /dev/null +++ b/tests/pytorch/test_deferred_init.py @@ -0,0 +1,87 @@ +# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import pytest +import torch +import torch.distributed as dist + +import transformer_engine.pytorch as te + +_core_modules = [ + te.LayerNorm, + te.RMSNorm, + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, +] + +_composed_modules = [ + te.MultiheadAttention, + te.TransformerLayer, +] + +batch_size = 32 +seq_length = 2048 +num_heads = 16 +head_dim = 64 +dtype = torch.bfloat16 + +class TestDeferredInit: + + @staticmethod + def get_module_args(module): + hidden_size = num_heads * head_dim + args = (hidden_size,) + kwargs = { + 'params_dtype': dtype, + 'device': 'meta' + } + if module in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]: + ffn_hidden_size = 2 * hidden_size + args += (ffn_hidden_size, ) + kwargs['bias'] = True + if module == te.LayerNormMLP: + kwargs['seq_length'] = seq_length + elif module == te.MultiheadAttention: + args += (num_heads, ) + kwargs['fuse_qkv_params'] = True + elif module == te.TransformerLayer: + args += (3 * hidden_size, num_heads) + kwargs['fuse_qkv_params'] = True + kwargs['seq_length'] = seq_length + + return args, kwargs + + @pytest.mark.parametrize("module_type", _core_modules+_composed_modules) + def test_zero_memory_init( + self, + module_type: torch.nn.Module, + ) -> None: + """Test deferred initialization via device='meta'.""" + # This should not allocate any memory on CUDA device until we call reset_parameters() later. + args, kwargs = TestDeferredInit.get_module_args(module_type) + module = module_type(*args, **kwargs) + assert torch.cuda.memory_allocated(device=0) == 0.0, ( + f"Initializing {module_type.__name__} with device='meta' prematurely allocated " + "memory on CUDA device" + ) + del module + + @pytest.mark.parametrize("module_type", _core_modules) + def test_reset_parameters( + self, + module_type: torch.nn.Module, + ) -> None: + """Test parameter reset for core modules that have been initialized with device='meta'.""" + # Core modules own their own parameters so calling reset_parameters() here should + # materialize them on CUDA device. 
+ args, kwargs = TestDeferredInit.get_module_args(module_type) + module = module_type(*args, **kwargs) + with torch.no_grad(): + module.reset_parameters() + assert torch.cuda.memory_allocated(device=0) > 0.0, ( + f"{module_type.__name__}.reset_parameters() failed to materialize parameters " + "on CUDA device" + ) + del module diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ad1f383617..f77e07a68f 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -769,7 +769,7 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: for name, param in self.named_parameters(recurse=False): # Ensure parameter is on a real device if param.device == torch.device('meta'): - param = param.to(device='cuda') + param = torch.empty_like(param, device='cuda') # Initialize the parameter values on device init_fn = self.param_init_meta[name].init_fn diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index fac941306f..6178199be6 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -138,8 +138,7 @@ def __init__( dtype=params_dtype, ) ) - setattr(self.weight, "sequence_parallel", sequence_parallel) - setattr(self.bias, "sequence_parallel", sequence_parallel) + self.sequence_parallel = sequence_parallel self.reset_parameters(defer_init=(device == 'meta')) @@ -168,7 +167,15 @@ def reset_parameters(self, defer_init=False) -> None: """Init LayerNorm parameters""" if defer_init: return + + if self.weight.device == torch.device('meta'): + self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device='cuda')) + setattr(self.weight, "sequence_parallel", self.sequence_parallel) init.constant_(self.weight, float(not self.zero_centered_gamma)) + + if self.bias.device == torch.device('meta'): + self.bias = torch.nn.Parameter(torch.empty_like(self.bias, device='cuda')) + setattr(self.bias, "sequence_parallel", self.sequence_parallel) init.zeros_(self.bias) @no_torch_dynamo() diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 589c787b74..2de860cf73 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -776,14 +776,12 @@ def __init__( ) self.register_parameter('layer_norm_weight', layer_norm_weight, init_fn=init_method_constant(float(not self.zero_centered_gamma))) - setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition if self.normalization != "RMSNorm": layer_norm_bias = torch.nn.Parameter( torch.empty(in_features, device=device, dtype=params_dtype) ) self.register_parameter('layer_norm_bias', layer_norm_bias, init_fn=init_method_constant(0.0)) - setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition else: self.layer_norm_bias = None @@ -876,22 +874,10 @@ def __init__( bias = torch.nn.Parameter(bias) self.register_parameter(self.bias_names[i], bias, init_fn=init_method_constant(0.0)) - if parallel_mode == "row": - bias.sequence_parallel = sequence_parallel else: bias = torch.Tensor().to(dtype=params_dtype, device=device) setattr(self, self.bias_names[i], bias) - # Configure tensor parallelism - set_tensor_model_parallel_attributes( - tensor=weight, - is_parallel=True, - dim=1 if 
parallel_mode == "row" else 0, - stride=1, - ) - if parallel_mode == "column": - set_tensor_model_parallel_attributes(bias, True, 0, 1) - # Concatenated tensors are not needed if not splitting # into multiple parameters if not is_subview: @@ -935,6 +921,33 @@ def reset_layer_norm_parameters(self) -> None: if self.layer_norm_bias is not None: init.zeros_(self.layer_norm_bias) + def reset_parameters(self, defer_init=False): + super().reset_parameters(defer_init=defer_init) + + if not defer_init: + # Set parallelism attributes for layer norm parameters + setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) + if self.normalization != "RMSNorm": + setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + + # Set parallelism attributes for linear weights + for weight in self.weight_names: + set_tensor_model_parallel_attributes( + tensor=getattr(self, weight), + is_parallel=True, + dim=1 if self.parallel_mode == "row" else 0, + stride=1, + ) + + # Set parallelism attributes for linear biases + if self.use_bias: + for bias in self.bias_names: + if self.parallel_mode == "row": + setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel) + elif self.parallel_mode == "column": + set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) + + def get_fp8_weights_scratchpad( self, is_first_microbatch: Union[bool, None], diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 54de8f16f8..d48ee4887d 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1208,14 +1208,12 @@ def __init__( ) self.register_parameter('layer_norm_weight', layer_norm_weight, init_fn=init_method_constant(float(not self.zero_centered_gamma))) - setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) if self.normalization != "RMSNorm": layer_norm_bias = Parameter( torch.empty(hidden_size, device=device, dtype=params_dtype) ) self.register_parameter('layer_norm_bias', layer_norm_bias, init_fn=init_method_constant(0.0)) - setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) # pylint: disable=access-member-before-definition else: self.layer_norm_bias = None @@ -1234,7 +1232,6 @@ def __init__( init_fn=init_method, get_rng_state_tracker=get_rng_state_tracker, fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT) - set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1) self.fp8_weight_shapes.append(self.fc1_weight.shape) if self.use_bias: @@ -1243,7 +1240,6 @@ def __init__( ) self.register_parameter('fc1_bias', fc1_bias, init_fn=init_method_constant(0.0)) - set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) # pylint: disable=access-member-before-definition else: self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device) @@ -1255,7 +1251,6 @@ def __init__( init_fn=output_layer_init_method, get_rng_state_tracker=get_rng_state_tracker, fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT) - set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1) self.fp8_weight_shapes.append(self.fc2_weight.shape) if self.use_bias: @@ -1264,9 +1259,6 @@ def __init__( ) self.register_parameter('fc2_bias', fc2_bias, init_fn=init_method_constant(0.0)) - # RPL - if self.set_parallel_mode: - setattr(self.fc2_bias, "sequence_parallel", sequence_parallel) # pylint: disable=access-member-before-definition else: self.fc2_bias = torch.Tensor().to(dtype=params_dtype, device=device) @@ -1312,6 
+1304,23 @@ def reset_layer_norm_parameters(self) -> None: if self.layer_norm_bias is not None: init.zeros_(self.layer_norm_bias) + def reset_parameters(self, defer_init=False): + super().reset_parameters(defer_init=defer_init) + + if not defer_init: + # Set parallel attributes for layer norm parameters + setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) + if self.normalization != "RMSNorm": + setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + + # Set parallel attributes for linear parameters + set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1) + set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1) + if self.use_bias: + set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) + if self.set_parallel_mode: + setattr(self.fc2_bias, "sequence_parallel", self.sequence_parallel) + def get_fp8_weights_scratchpad( self, is_first_microbatch: Union[bool, None], diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 88eb6080e8..68c5bf1a1d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -767,22 +767,10 @@ def __init__( bias = torch.nn.Parameter(bias) self.register_parameter(self.bias_names[i], bias, init_fn=init_method_constant(0.0)) - if parallel_mode == "row": - bias.sequence_parallel = sequence_parallel else: bias = torch.Tensor().to(dtype=params_dtype, device=device) setattr(self, self.bias_names[i], bias) - # Configure tensor parallelism - set_tensor_model_parallel_attributes( - tensor=weight, - is_parallel=True, - dim=1 if parallel_mode == "row" else 0, - stride=1, - ) - if parallel_mode == "column": - set_tensor_model_parallel_attributes(bias, True, 0, 1) - # Concatenated tensors are not needed if not splitting # into multiple parameters if not is_subview: @@ -804,6 +792,27 @@ def __init__( else: self.gemm_bias_unfused_add = False + def reset_parameters(self, defer_init=False): + super().reset_parameters(defer_init=defer_init) + + if not defer_init: + # Set parallelism attributes for linear weights + for weight in self.weight_names: + set_tensor_model_parallel_attributes( + tensor=getattr(self, weight), + is_parallel=True, + dim=1 if self.parallel_mode == "row" else 0, + stride=1, + ) + + # Set parallelism attributes for linear biases + if self.use_bias: + for bias in self.bias_names: + if self.parallel_mode == "row": + setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel) + elif self.parallel_mode == "column": + set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) + def get_fp8_weights_scratchpad( self, is_first_microbatch: Union[bool, None], diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index cad357de04..4b1b2c749a 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -141,7 +141,7 @@ def __init__( dtype=params_dtype, ) ) - setattr(self.weight, "sequence_parallel", sequence_parallel) + self.sequence_parallel = sequence_parallel self.reset_parameters(defer_init=(device == 'meta')) @@ -169,7 +169,11 @@ def reset_parameters(self, defer_init=False) -> None: """Reset RMSNorm parameters""" if defer_init: return + + if self.weight.device == torch.device('meta'): + self.weight = torch.nn.Parameter(torch.empty_like(self.weight, device='cuda')) init.constant_(self.weight, float(not self.zero_centered_gamma)) + setattr(self.weight, "sequence_parallel", 
self.sequence_parallel) @no_torch_dynamo() def forward(self, inp: torch.Tensor) -> torch.Tensor: From ffdd519647701a34ec05e5cea54a0f35ecfbe64e Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Wed, 24 Jan 2024 09:32:11 -0600 Subject: [PATCH 08/15] [PyTorch] Workaround for incorrect output from torch.cuda.is_bf16_compatible() on V100s and TU102s (#626) * replaced torch.cuda.is_bf16_compatible() with explicit sm_80 check via torch.cuda.get_device_capability() Signed-off-by: Alp Dener * implement te.utils.is_bf16_compatible() to replace torch.cuda counterpart Signed-off-by: Alp Dener --------- Signed-off-by: Alp Dener --- tests/pytorch/fused_attn/test_fused_attn.py | 3 ++- tests/pytorch/test_numerics.py | 3 ++- tests/pytorch/test_sanity.py | 3 ++- transformer_engine/pytorch/utils.py | 6 ++++++ 4 files changed, 12 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 296d9ff214..42ffb32ad1 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -41,6 +41,7 @@ get_device_compute_capability, init_method_normal, scaled_init_method_normal, + is_bf16_compatible, ) import transformer_engine_extensions as tex from transformer_engine_extensions import NVTE_Fused_Attn_Backend @@ -194,7 +195,7 @@ def _is_unfused_attention_supported(config: ModelConfig) -> bool: } param_types = [torch.float16] -if torch.cuda.is_bf16_supported(): +if is_bf16_compatible(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) param_types_lean = [torch.bfloat16] diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 215cae2b97..4f5a9807c1 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -17,6 +17,7 @@ init_method_normal, scaled_init_method_normal, attention_mask_func, + is_bf16_compatible, ) from transformer_engine.pytorch import ( DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, @@ -53,7 +54,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq } param_types = [torch.float32, torch.float16] -if torch.cuda.is_bf16_supported(): +if is_bf16_compatible(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) batch_sizes = [1, 2] diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 593231d6d1..ae960369c4 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -13,6 +13,7 @@ from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, + is_bf16_compatible, ) from transformer_engine.pytorch import ( LayerNormLinear, @@ -101,7 +102,7 @@ def is_fp8_supported(self): ] param_types = [torch.float32, torch.float16] -if torch.cuda.is_bf16_supported(): +if is_bf16_compatible(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) all_boolean = [True, False] diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 819b3d4827..824508077b 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -222,3 +222,9 @@ def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None: "Tensor dimensions are not compatible for FP8 execution: " f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)" ) + +def is_bf16_compatible() -> None: + """Replaces torch.cuda.is_bf16_compatible() with an explicit + check on device compute capability to enforce sm_80 or higher. 
+ """ + return torch.cuda.get_device_capability()[0] >= 8 From 8571f6999ffea902166d82dffe4ab0675d86e35f Mon Sep 17 00:00:00 2001 From: Marks101 <46690260+Marks101@users.noreply.github.com> Date: Wed, 24 Jan 2024 18:48:21 +0100 Subject: [PATCH 09/15] [PyTorch] forward attention_type in MultiHeadAttention (#621) [PyTorch] fix forward attention_type in MultiheadAttention Signed-off-by: Markus Schnoes Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index cf7bee8c66..7bf0678898 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3090,6 +3090,7 @@ def __init__( sequence_parallel=sequence_parallel, tp_group=tp_group, layer_number=self.layer_number, + attention_type=self.attention_type, ) # Linear From 18186b410ad968b21ed841a5a03bd5574b96ab12 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 24 Jan 2024 10:13:08 -0800 Subject: [PATCH 10/15] Fix compatibility with pyTorch 2.0 (#627) Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/jit.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index 5fb1768ba6..684004a27e 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -17,7 +17,12 @@ no_torch_dynamo = lambda recursive=True: lambda func: func if torch.__version__ >= "2": import torch._dynamo - no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive) + if torch.__version__ >= "2.1": + no_torch_dynamo = lambda recursive=True: lambda f: \ + torch._dynamo.disable(f, recursive=recursive) + else: + # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True + no_torch_dynamo = lambda recursive=True: torch._dynamo.disable def set_jit_fusion_options() -> None: From bcbe9b0365b649695a325f720423b4fa61d37527 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 24 Jan 2024 16:55:50 -0800 Subject: [PATCH 11/15] Revert "Avoid redundant computation for cu_seqlens (#535)" This reverts commit fad3044bde1547eae9543a6a3f80401e59bb629e. 
Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/attention.py | 32 +++++++++++-------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 7bf0678898..a8300bad87 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1621,24 +1621,20 @@ def forward( query_layer_packed, key_layer_packed, value_layer_packed) cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv else: - if self.layer_number == 1: - if cu_seqlens_q is None: - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * max_seqlen_q, - step=max_seqlen_q, - dtype=torch.int32, - device=query_layer.device) - if cu_seqlens_kv is None: - cu_seqlens_kv = torch.arange( - 0, - (batch_size + 1) * max_seqlen_kv, - step=max_seqlen_kv, - dtype=torch.int32, - device=key_layer.device) - _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv - else: - cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv + if cu_seqlens_q is None: + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * max_seqlen_q, + step=max_seqlen_q, + dtype=torch.int32, + device=query_layer.device) + if cu_seqlens_kv is None: + cu_seqlens_kv = torch.arange( + 0, + (batch_size + 1) * max_seqlen_kv, + step=max_seqlen_kv, + dtype=torch.int32, + device=key_layer.device) elif qkv_format == 'thd': assert not context_parallel, "thd format not supported with context parallelism!" assert (cu_seqlens_q is not None and cu_seqlens_kv is not None From e7319f55e3f41886a2a9ceb3c7a45fd809daffb0 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Fri, 26 Jan 2024 12:59:40 -0800 Subject: [PATCH 12/15] Fix pipeline parallelism with FusedAttn (#635) Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/attention.py | 86 +++++++++++-------------- 1 file changed, 39 insertions(+), 47 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index a8300bad87..469791c5d5 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1587,32 +1587,30 @@ def forward( assert ( max_seqlen_q == max_seqlen_kv ), "Maximum sequence length for Q and KV should be the same." - if self.layer_number == 1: - if cu_seqlens_q is None: - assert (attention_mask is not None - ), "Please provide attention_mask for padding!" - _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask) - else: - _cu_seqlens_q = cu_seqlens_q - _indices_q = get_indices(max_seqlen_q, cu_seqlens_q) + if cu_seqlens_q is None: + assert (attention_mask is not None + ), "Please provide attention_mask for padding!" + _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices(attention_mask) + else: + _cu_seqlens_q = cu_seqlens_q + _indices_q = get_indices(max_seqlen_q, cu_seqlens_q) _cu_seqlens_kv = _cu_seqlens_q query_layer_packed, key_layer_packed, value_layer_packed = PackTensors.apply( _indices_q, query_layer, key_layer, value_layer ) else: - if self.layer_number == 1: - if cu_seqlens_q is None or cu_seqlens_kv is None: - assert (attention_mask is not None - ), "Please provide attention_mask for padding!" 
- _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices( - attention_mask[0]) - _cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices( - attention_mask[1]) - else: - _cu_seqlens_q = cu_seqlens_q - _cu_seqlens_kv = cu_seqlens_kv - _indices_q = get_indices(max_seqlen_q, cu_seqlens_q) - _indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv) + if cu_seqlens_q is None or cu_seqlens_kv is None: + assert (attention_mask is not None + ), "Please provide attention_mask for padding!" + _cu_seqlens_q, _indices_q = get_cu_seqlens_and_indices( + attention_mask[0]) + _cu_seqlens_kv, _indices_kv = get_cu_seqlens_and_indices( + attention_mask[1]) + else: + _cu_seqlens_q = cu_seqlens_q + _cu_seqlens_kv = cu_seqlens_kv + _indices_q = get_indices(max_seqlen_q, cu_seqlens_q) + _indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv) query_layer_packed = PackTensors.apply(_indices_q, query_layer) key_layer_packed, value_layer_packed = PackTensors.apply( _indices_kv, key_layer, value_layer @@ -2030,39 +2028,33 @@ def forward( global _cu_seqlens_q, _cu_seqlens_kv if (cu_seqlens_q is not None and cu_seqlens_kv is not None): # use cu_seqlens when both cu_seqlens and attention_mask are present - if self.layer_number == 1: - _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv + _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv elif attention_mask is not None: if self.attention_type == "self": - if self.layer_number == 1: - _cu_seqlens_q = get_cu_seqlens(attention_mask) - _cu_seqlens_kv = _cu_seqlens_q + _cu_seqlens_q = get_cu_seqlens(attention_mask) + _cu_seqlens_kv = _cu_seqlens_q else: - if self.layer_number == 1: - _cu_seqlens_q = get_cu_seqlens(attention_mask[0]) - _cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) + _cu_seqlens_q = get_cu_seqlens(attention_mask[0]) + _cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) else: raise Exception("Please provide attention_mask or cu_seqlens for padding!") cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv else: - if self.layer_number == 1: - if cu_seqlens_q is None: - cu_seqlens_q = torch.arange( - 0, - (batch_size + 1) * max_seqlen_q, - step=max_seqlen_q, - dtype=torch.int32, - device=query_layer.device) - if cu_seqlens_kv is None: - cu_seqlens_kv = torch.arange( - 0, - (batch_size + 1) * max_seqlen_kv, - step=max_seqlen_kv, - dtype=torch.int32, - device=key_layer.device) - _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv - else: - cu_seqlens_q, cu_seqlens_kv = _cu_seqlens_q, _cu_seqlens_kv + if cu_seqlens_q is None: + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * max_seqlen_q, + step=max_seqlen_q, + dtype=torch.int32, + device=query_layer.device) + if cu_seqlens_kv is None: + cu_seqlens_kv = torch.arange( + 0, + (batch_size + 1) * max_seqlen_kv, + step=max_seqlen_kv, + dtype=torch.int32, + device=key_layer.device) + _cu_seqlens_q, _cu_seqlens_kv = cu_seqlens_q, cu_seqlens_kv qkv_dtype = TE_DType[query_layer.dtype] From f15b70744a0aebe5aca9d3466ba81805cd36f3de Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Mon, 29 Jan 2024 16:00:01 -0800 Subject: [PATCH 13/15] Fixed offloading for PyT version/ Added Attention activation offloading support/ Native FP8 support (#632) * Fixed offloading for PyT version/ Added Attention activation offloading support/ Native FP8 support Signed-off-by: Selvaraj Anandaraj * Removed activation offloading for fused attention Signed-off-by: Selvaraj Anandaraj * Fixed the illegal memory access issue for activation offloading of attention Signed-off-by: Selvaraj Anandaraj * Removed the 
version guard Signed-off-by: Selvaraj Anandaraj * Pipeline failures fix Signed-off-by: Selvaraj Anandaraj * Fixed lint erros Signed-off-by: Selvaraj Anandaraj * Lint error fix Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj --- transformer_engine/pytorch/attention.py | 24 ++++++++++ transformer_engine/pytorch/cpu_offload.py | 46 +++++++++++++------ .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 3 +- transformer_engine/pytorch/module/linear.py | 2 +- 5 files changed, 59 insertions(+), 18 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 469791c5d5..b7a98de0cd 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1662,6 +1662,14 @@ def forward( deterministic=self.deterministic ) else: + + from .cpu_offload import CPUOffloadEnabled + if CPUOffloadEnabled: + tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv] + for tensor in tensor_list: + if tensor is not None: + tensor.activation_offloading = True + with self.attention_dropout_ctx(): fa_optional_forward_kwargs = {} if _flash_attn_2_3_plus: @@ -1848,6 +1856,15 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, rng_gen) + from .cpu_offload import CPUOffloadEnabled + if CPUOffloadEnabled: + tensor_list = [q, k, v, out, cu_seqlens_q, cu_seqlens_kv] + qkv_layout = 'sbhd_sbhd_sbhd' + for tensor in tensor_list: + if tensor is not None: + tensor.activation_offloading = True + + ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv) ctx.aux_ctx_tensors = aux_ctx_tensors ctx.max_seqlen_q = max_seqlen_q @@ -2722,6 +2739,13 @@ def forward( assert (not context_parallel), \ "Context parallelism is only implemented with Flash Attention and Fused Attention!" + from .cpu_offload import CPUOffloadEnabled + if CPUOffloadEnabled: + warnings.warn( + "Attention activation Offloading is only implemented" + "with Flash Attention and Fused Attention!" + ) + if _NVTE_DEBUG: print("[DotProductAttention]: using unfused DPA") if use_unfused_attention: diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index dcede62ef7..b2635bb9bf 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -184,6 +184,7 @@ def groupid_reset(self): # the tensor back to gpu and deletes the cpu tensor. 
# These will increment whenever `group_commit()` is invoked self.current_group, self.tensor_count_current_group = (0, 0) + self.torch_tensor_count = 0 self.tensor_tag_to_state = {} def on_group_commit_forward(self): @@ -310,24 +311,35 @@ def get_tensor_buf_for_offloaded_tensor(self, tensor, tensor_tag): def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: - # obtain a unique tensor tag - tensor_tag = (self.current_group, self.tensor_count_current_group) - self.tensor_count_current_group += 1 - assert tensor_tag not in self.tensor_tag_to_state - if (self.current_group < self.num_offload_group - and self.tensor_need_offloading_checker(tensor)): - # first copy the tensor to tensorbuf, so that the original tensor will not be deleted - tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag) - tensor_buf.copy_(tensor) - if hasattr(tensor,"weight_offloading"): - tensor_buf.weight_offloading = True - if hasattr(tensor,"activation_offloading"): - tensor_buf.activation_offloading = True - # Here we just save it, and at commit, bulk_offload_group will handle it - self.tensor_tag_to_state[tensor_tag] = tensor_buf + torch_stray_tensor = isinstance(tensor,(torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor)) + + if not torch_stray_tensor: + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + + if (self.current_group < self.num_offload_group + and self.tensor_need_offloading_checker(tensor)): + # first copy the tensor to tensorbuf, + # so that the original tensor will not be deleted + tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag) + tensor_buf.copy_(tensor) + if hasattr(tensor,"weight_offloading"): + tensor_buf.weight_offloading = True + if hasattr(tensor,"activation_offloading"): + tensor_buf.activation_offloading = True + # Here we just save it, and at commit, bulk_offload_group will handle it + self.tensor_tag_to_state[tensor_tag] = tensor_buf + else: + self.tensor_tag_to_state[tensor_tag] = tensor else: + tensor_tag = (-1,self.torch_tensor_count) + self.torch_tensor_count += 1 self.tensor_tag_to_state[tensor_tag] = tensor + return tensor_tag def tensor_pop(self, tensor_tag, **kwargs): @@ -350,6 +362,10 @@ def bulk_offload_group(self, group_to_offload): # if offload, return the reference to cpu copy if self.tensor_need_offloading_checker(tensor_on_device): + if hasattr(tensor_on_device,"weight_offloading"): + delattr(tensor_on_device,"weight_offloading") + if hasattr(tensor_on_device,"activation_offloading"): + delattr(tensor_on_device,"activation_offloading") state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) self.tensor_tag_to_state[tensor_tag] = state diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 2de860cf73..6836ef6d22 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -242,7 +242,7 @@ def forward( if cpu_offloading: if fuse_wgrad_accumulation: weight.main_grad.weight_offloading = True - if fp8: + if fp8 and weight_t_fp8 is not None: weight_t_fp8.weight_offloading = True ln_weight.weight_offloading = True weight.weight_offloading = True diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index d48ee4887d..3a0e5cb559 100644 --- 
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -424,8 +424,9 @@ def forward(
             if fuse_wgrad_accumulation:
                 fc1_weight.main_grad.weight_offloading = True
                 fc2_weight.main_grad.weight_offloading = True
-            if fp8:
+            if fp8 and fc1_weight_t_fp8 is not None:
                 fc1_weight_t_fp8.weight_offloading = True
+            if fp8 and fc2_weight_t_fp8 is not None:
                 fc2_weight_t_fp8.weight_offloading = True
             ln_weight.weight_offloading = True
             fc1_weight.weight_offloading = True
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 68c5bf1a1d..f2c955bfc0 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -275,7 +275,7 @@ def forward(
         if cpu_offloading:
             if fuse_wgrad_accumulation:
                 weight.main_grad.weight_offloading = True
-            if fp8:
+            if fp8 and weight_t_fp8 is not None:
                 weight_t_fp8.weight_offloading = True
             weight.weight_offloading = True
 
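The offloading changes in this patch hinge on one new check: tensors produced by torch.compile tracing (FakeTensor, FunctionalTensor) carry no real storage and must bypass the offload buffers. A minimal standalone sketch of that check follows; it assumes a PyTorch release that provides these private classes (2.1 or newer), and the helper name is illustrative, not part of Transformer Engine.

```python
# Sketch of the stray-tensor filter that tensor_push() now applies before offloading.
# Assumes PyTorch >= 2.1; `is_torch_stray_tensor` is a hypothetical helper name.
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor


def is_torch_stray_tensor(tensor: torch.Tensor) -> bool:
    """True for tracing-only tensors; copying them into a pinned CPU buffer
    would be meaningless, so the handler stores them under a (-1, n) tag."""
    return isinstance(tensor, (FakeTensor, FunctionalTensor))


if __name__ == "__main__":
    print(is_torch_stray_tensor(torch.randn(4, 4)))  # False for an ordinary tensor
```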
From df9c29e6a2cff8413acfc8c471a8f0417ebecec5 Mon Sep 17 00:00:00 2001
From: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Date: Wed, 31 Jan 2024 08:20:19 -0800
Subject: [PATCH 14/15] Update FindCUDNN.cmake for cuDNN 9 (#640)

* update cudnn cmake for v9

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* add back license information

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
---
 transformer_engine/cmake/FindCUDNN.cmake | 82 ++++++++++++++++--------
 1 file changed, 57 insertions(+), 25 deletions(-)

diff --git a/transformer_engine/cmake/FindCUDNN.cmake b/transformer_engine/cmake/FindCUDNN.cmake
index 6d7455919e..065174e62a 100644
--- a/transformer_engine/cmake/FindCUDNN.cmake
+++ b/transformer_engine/cmake/FindCUDNN.cmake
@@ -8,25 +8,29 @@ find_path(
     CUDNN_INCLUDE_DIR cudnn.h
     HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_INCLUDE_DIRS}
     PATH_SUFFIXES include
+    REQUIRED
 )
 
-function(find_cudnn_library NAME)
-  string(TOUPPER ${NAME} UPPERCASE_NAME)
+file(READ "${CUDNN_INCLUDE_DIR}/cudnn_version.h" cudnn_version_header)
+string(REGEX MATCH "#define CUDNN_MAJOR [1-9]+" macrodef "${cudnn_version_header}")
+string(REGEX MATCH "[1-9]+" CUDNN_MAJOR_VERSION "${macrodef}")
 
+function(find_cudnn_library NAME)
   find_library(
-    ${UPPERCASE_NAME}_LIBRARY ${NAME}
+    ${NAME}_LIBRARY ${NAME} "lib${NAME}.so.${CUDNN_MAJOR_VERSION}"
     HINTS $ENV{CUDNN_PATH} ${CUDNN_PATH} ${CUDAToolkit_LIBRARY_DIR}
     PATH_SUFFIXES lib64 lib/x64 lib
+    REQUIRED
   )
-
-  if(${UPPERCASE_NAME}_LIBRARY)
+
+  if(${NAME}_LIBRARY)
     add_library(CUDNN::${NAME} UNKNOWN IMPORTED)
     set_target_properties(
       CUDNN::${NAME} PROPERTIES
       INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}
-      IMPORTED_LOCATION ${${UPPERCASE_NAME}_LIBRARY}
+      IMPORTED_LOCATION ${${NAME}_LIBRARY}
     )
-    message(STATUS "${NAME} found at ${${UPPERCASE_NAME}_LIBRARY}.")
+    message(STATUS "${NAME} found at ${${NAME}_LIBRARY}.")
   else()
     message(STATUS "${NAME} not found.")
   endif()
@@ -35,24 +39,18 @@ function(find_cudnn_library NAME)
 endfunction()
 
 find_cudnn_library(cudnn)
-find_cudnn_library(cudnn_adv_infer)
-find_cudnn_library(cudnn_adv_train)
-find_cudnn_library(cudnn_cnn_infer)
-find_cudnn_library(cudnn_cnn_train)
-find_cudnn_library(cudnn_ops_infer)
-find_cudnn_library(cudnn_ops_train)
 
 include (FindPackageHandleStandardArgs)
 find_package_handle_standard_args(
-  CUDNN REQUIRED_VARS
-  CUDNN_INCLUDE_DIR CUDNN_LIBRARY
+  LIBRARY REQUIRED_VARS
+  CUDNN_INCLUDE_DIR cudnn_LIBRARY
 )
 
-if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY)
+if(CUDNN_INCLUDE_DIR AND cudnn_LIBRARY)
-  message(STATUS "cuDNN: ${CUDNN_LIBRARY}")
+  message(STATUS "cuDNN: ${cudnn_LIBRARY}")
   message(STATUS "cuDNN: ${CUDNN_INCLUDE_DIR}")
-
+
   set(CUDNN_FOUND ON CACHE INTERNAL "cuDNN Library Found")
 else()
@@ -71,11 +69,45 @@ target_include_directories(
 target_link_libraries(
     CUDNN::cudnn_all
     INTERFACE
-    CUDNN::cudnn_adv_train
-    CUDNN::cudnn_ops_train
-    CUDNN::cudnn_cnn_train
-    CUDNN::cudnn_adv_infer
-    CUDNN::cudnn_cnn_infer
-    CUDNN::cudnn_ops_infer
-    CUDNN::cudnn
+    CUDNN::cudnn
 )
+
+if(CUDNN_MAJOR_VERSION EQUAL 8)
+  find_cudnn_library(cudnn_adv_infer)
+  find_cudnn_library(cudnn_adv_train)
+  find_cudnn_library(cudnn_cnn_infer)
+  find_cudnn_library(cudnn_cnn_train)
+  find_cudnn_library(cudnn_ops_infer)
+  find_cudnn_library(cudnn_ops_train)
+
+  target_link_libraries(
+    CUDNN::cudnn_all
+    INTERFACE
+    CUDNN::cudnn_adv_train
+    CUDNN::cudnn_ops_train
+    CUDNN::cudnn_cnn_train
+    CUDNN::cudnn_adv_infer
+    CUDNN::cudnn_cnn_infer
+    CUDNN::cudnn_ops_infer
+  )
+elseif(CUDNN_MAJOR_VERSION EQUAL 9)
+  find_cudnn_library(cudnn_cnn)
+  find_cudnn_library(cudnn_adv)
+  find_cudnn_library(cudnn_graph)
+  find_cudnn_library(cudnn_ops)
+  find_cudnn_library(cudnn_engines_runtime_compiled)
+  find_cudnn_library(cudnn_engines_precompiled)
+  find_cudnn_library(cudnn_heuristic)
+
+  target_link_libraries(
+    CUDNN::cudnn_all
+    INTERFACE
+    CUDNN::cudnn_adv
+    CUDNN::cudnn_ops
+    CUDNN::cudnn_cnn
+    CUDNN::cudnn_graph
+    CUDNN::cudnn_engines_runtime_compiled
+    CUDNN::cudnn_engines_precompiled
+    CUDNN::cudnn_heuristic
+  )
+endif()
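The CMake update keys everything off CUDNN_MAJOR parsed from cudnn_version.h and then links either the cuDNN 8 split libraries (*_infer/*_train) or the cuDNN 9 component libraries (graph, heuristic, engines_*). As a quick sanity check outside the build, the same value can be read with a few lines of Python; the header location below is an assumption, adjust CUDNN_PATH for your install.

```python
# Sketch: report the cuDNN major version the way the updated FindCUDNN.cmake
# detects it. The default include path is an assumption for this example.
import os
import re

include_dir = os.path.join(os.environ.get("CUDNN_PATH", "/usr"), "include")
header = os.path.join(include_dir, "cudnn_version.h")

with open(header, encoding="utf-8") as f:
    match = re.search(r"#define CUDNN_MAJOR ([0-9]+)", f.read())

major = int(match.group(1)) if match else None
print(f"CUDNN_MAJOR = {major}")
if major == 8:
    print("cuDNN 8: expect the cudnn_adv/cnn/ops _infer/_train libraries")
elif major == 9:
    print("cuDNN 9: expect cudnn_graph, cudnn_heuristic, cudnn_engines_* libraries")
```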
From 5b90b7f5ed67b373bc5f843d1ac3b7a8999df08e Mon Sep 17 00:00:00 2001
From: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Date: Fri, 2 Feb 2024 20:36:10 -0800
Subject: [PATCH 15/15] Update cudnn-frontend to 1.0.3 to fix cuDNN v9 SDPA NaNs (#650)

* Update cudnn frontend to 1.0.3 to fix cudnn v9 Nans

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* make d_out contiguous for bwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove cudnnDestroy to let torch handle it

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

* Update transformer_engine/pytorch/attention.py

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: cyanguwa <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
---
 3rdparty/cudnn-frontend                      | 2 +-
 transformer_engine/common/fused_attn/utils.h | 5 -----
 transformer_engine/pytorch/attention.py      | 3 +++
 3 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend
index 9f82dda5c0..a86ad708db 160000
--- a/3rdparty/cudnn-frontend
+++ b/3rdparty/cudnn-frontend
@@ -1 +1 @@
-Subproject commit 9f82dda5c029d15a5f371f0fe003dc0c74a0c987
+Subproject commit a86ad708db725e4d29919bb6fadf8e6cdfa5dc06
diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h
index 9da0dc553a..44288dd754 100644
--- a/transformer_engine/common/fused_attn/utils.h
+++ b/transformer_engine/common/fused_attn/utils.h
@@ -152,11 +152,6 @@ class cudnnExecutionPlanManager {
   }
 
   ~cudnnExecutionPlanManager() {
-    static thread_local std::once_flag flag;
-    std::call_once(flag, [&] {
-      if (handle_ != nullptr) {
-        cudnnDestroy(handle_);
-      }});
   }
 
  private:
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index b7a98de0cd..27c031e267 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1733,6 +1733,7 @@ def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias,
 
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         qkv, out, cu_seqlens = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
@@ -1802,6 +1803,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
 
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         q, kv, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
@@ -1883,6 +1885,7 @@ def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seql
 
     @staticmethod
     def backward(ctx, d_out):
+        d_out = d_out.contiguous()
         q, k, v, out, cu_seqlens_q, cu_seqlens_kv = ctx.saved_tensors
         if not ctx.aux_ctx_tensors[0].is_contiguous():
             ctx.aux_ctx_tensors[0] = ctx.aux_ctx_tensors[0].contiguous()
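Each of the three attention backward hunks adds the same guard: the incoming gradient is made contiguous before it reaches the fused-attention kernels, because autograd can hand backward() a transposed, non-contiguous d_out. A short PyTorch sketch of that situation, independent of Transformer Engine and with arbitrary shapes, is below.

```python
# Sketch: a transpose upstream of the loss yields a non-contiguous gradient,
# which is why backward() normalizes the layout with .contiguous().
import torch

d_out = torch.ones(64, 8, 2).permute(2, 1, 0)   # stand-in for an incoming grad
print(d_out.shape, d_out.is_contiguous())       # torch.Size([2, 8, 64]) False

d_out = d_out.contiguous()                      # same values, packed strides
print(d_out.is_contiguous())                    # True
```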