From f74b52710e3addc57bee5136351561589992d0a8 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 14 Jul 2023 08:02:09 -0700 Subject: [PATCH 1/3] Deprecate unused APIs Signed-off-by: Kirthi Shankar Sivamani --- .../pytorch/module/layernorm_linear.py | 146 ++++++++++-------- transformer_engine/pytorch/module/linear.py | 146 ++++++++++-------- 2 files changed, 156 insertions(+), 136 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 43c3e3f165..03f96cc075 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -4,6 +4,7 @@ """LayerNormLinear API""" import os +import warnings from typing import Union, Optional, Callable, Tuple, List, Dict, Any @@ -538,6 +539,11 @@ class LayerNormLinear(TransformerEngineBaseModule): r""" Applies layer normalization followed by linear transformation to the incoming data. + .. warning:: + + Argument :attr:`skip_weight_param_allocation` is deprecated and will + be fully removed in future releases. + Parameters ---------- in_features : int @@ -585,9 +591,6 @@ class LayerNormLinear(TransformerEngineBaseModule): used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to `None`, no communication is performed. - skip_weight_param_allocation: bool, default = `False` - if set to `True`, weight parameter is not allocated and must be - passed as a keyword argument `weight` during the forward pass. Optimization parameters ----------------------- @@ -624,7 +627,7 @@ def __init__( params_dtype: Optional[torch.dtype] = None, parallel_mode: Optional[str] = None, return_layernorm_output: bool = False, - skip_weight_param_allocation: bool = False, + skip_weight_param_allocation: bool = False, # pylint: disable=unused-argument parameters_split: Optional[Tuple[str, ...]] = None, zero_centered_gamma: bool = False, ub_bulk_wgrad: bool = False, @@ -633,6 +636,12 @@ def __init__( ) -> None: super().__init__() + warnings.warn( + "Argument `skip_weight_param_allocation` is deprecated and" + "will be fully removed in future releases.", + category=DeprecationWarning, + ) + params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features self.out_features = out_features @@ -695,72 +704,71 @@ def __init__( setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) self.reset_layer_norm_parameters() - if not skip_weight_param_allocation: - self.weight_tensor = torch.empty( - self.out_features, self.in_features, + self.weight_tensor = torch.empty( + self.out_features, self.in_features, + device=torch.cuda.current_device(), + dtype=params_dtype) + + initialize_affine_weight_gpu( + self.weight_tensor, + init_method, + get_rng_state_tracker, + partition_dim=1 if self.parallel_mode == "row" else 0, + stride=1, + ) + + if self.use_bias: + self.bias_tensor = torch.empty( + self.out_features, device=torch.cuda.current_device(), dtype=params_dtype) + else: + self.bias_tensor = torch.Tensor().to(dtype=params_dtype, + device=torch.cuda.current_device()) - initialize_affine_weight_gpu( - self.weight_tensor, - init_method, - get_rng_state_tracker, - partition_dim=1 if self.parallel_mode == "row" else 0, - stride=1, - ) + with torch.no_grad(): + self.bias_tensor.zero_() - if self.use_bias: - self.bias_tensor = torch.empty( - self.out_features, - device=torch.cuda.current_device(), - 
dtype=params_dtype) - else: - self.bias_tensor = torch.Tensor().to(dtype=params_dtype, - device=torch.cuda.current_device()) + if parameters_split is None: + parameters_split = ("",) - with torch.no_grad(): - self.bias_tensor.zero_() + assert ( + self.out_features % len(parameters_split) == 0 + ), f"Weight and bias params cannot be split into {len(parameters_split)} parts" - if parameters_split is None: - parameters_split = ("",) + split_size = self.out_features // len(parameters_split) - assert ( - self.out_features % len(parameters_split) == 0 - ), f"Weight and bias params cannot be split into {len(parameters_split)} parts" + self.weight_names = [] + self.bias_names = [] - split_size = self.out_features // len(parameters_split) + for i, pname in enumerate(parameters_split): + wname = pname + "weight" + bname = pname + "bias" - self.weight_names = [] - self.bias_names = [] + self.register_parameter( + wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size]) + ) - for i, pname in enumerate(parameters_split): - wname = pname + "weight" - bname = pname + "bias" + set_tensor_model_parallel_attributes( + tensor=getattr(self, wname), + is_parallel=True, + dim=1 if parallel_mode == "row" else 0, + stride=1, + ) + if self.use_bias: self.register_parameter( - wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size]) + bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) ) + else: + setattr(self, bname, torch.Tensor().to(dtype=params_dtype, + device=torch.cuda.current_device())) - set_tensor_model_parallel_attributes( - tensor=getattr(self, wname), - is_parallel=True, - dim=1 if parallel_mode == "row" else 0, - stride=1, - ) - - if self.use_bias: - self.register_parameter( - bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) - ) - else: - setattr(self, bname, torch.Tensor().to(dtype=params_dtype, - device=torch.cuda.current_device())) - - if parallel_mode == "column": - set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1) + if parallel_mode == "column": + set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1) - self.weight_names.append(wname) - self.bias_names.append(bname) + self.weight_names.append(wname) + self.bias_names.append(bname) self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) @@ -814,24 +822,22 @@ def get_fp8_weights_scratchpad( def forward( self, inp: torch.Tensor, - weight: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, + weight: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply layer normalization to the input followed by a linear transformation. + .. warning:: + + Arguments :attr:`weight` and :attr:`bias` are deprecated and will + be fully removed in future releases. + Parameters ---------- inp : torch.Tensor Input tensor. - weight : torch.Tensor, default = None - An optional weight tensor for the module. This argument is compulsory if module - is initialized with `skip_weight_param_allocation=True` - bias : torch.Tensor, default = None - An optional bias tensor for the module. 
This argument is compulsory if module - is initialized with `skip_weight_param_allocation=True` and one of `use_bias` - or `return_bias` is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split @@ -847,16 +853,20 @@ def forward( produced) """ + warnings.warn( + "Arguments `weight` and `bias` are deprecated and" + "will be fully removed in future releases.", + category=DeprecationWarning, + ) + with self.prepare_forward(inp, is_first_microbatch) as inp: bias_tensor = ( - bias if bias is not None - else self.bias if self.parameters_split is None + self.bias if self.parameters_split is None else self.bias_tensor if not torch.is_grad_enabled() else self.noop_cat("bias_tensor", self.bias_names) ) weight_tensor = ( - weight if weight is not None - else self.weight if self.parameters_split is None + self.weight if self.parameters_split is None else self.weight_tensor if not torch.is_grad_enabled() else self.noop_cat("weight_tensor", self.weight_names) ) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 054e28e901..1378ea7e98 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Linear API""" +import warnings from typing import Union, Optional, Callable, Tuple, List, Dict, Any import torch @@ -441,6 +442,11 @@ class Linear(TransformerEngineBaseModule): On NVIDIA GPUs it is a drop-in replacement for `torch.nn.Linear`. + .. warning:: + + Argument :attr:`skip_weight_param_allocation` is deprecated and will + be fully removed in future releases. + Parameters ---------- in_features : int @@ -474,9 +480,6 @@ class Linear(TransformerEngineBaseModule): used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to `None`, no communication is performed. - skip_weight_param_allocation: bool, default = `False` - if set to `True`, weight parameter is not allocated and must be - passed as a keyword argument `weight` during the forward pass. 
Optimization parameters ----------------------- @@ -511,13 +514,19 @@ def __init__( return_bias: bool = False, params_dtype: Optional[torch.dtype] = None, parallel_mode: Optional[str] = None, - skip_weight_param_allocation: bool = False, + skip_weight_param_allocation: bool = False, # pylint: disable=unused-argument parameters_split: Optional[Tuple[str, ...]] = None, ub_split_rs: bool = False, ub_split_ag: bool = False, ) -> None: super().__init__() + warnings.warn( + "Argument `skip_weight_param_allocation` is deprecated and" + "will be fully removed in future releases.", + category=DeprecationWarning, + ) + params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features self.out_features = out_features @@ -558,72 +567,71 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel - if not skip_weight_param_allocation: - self.weight_tensor = torch.empty( - self.out_features, self.in_features, + self.weight_tensor = torch.empty( + self.out_features, self.in_features, + device=torch.cuda.current_device(), + dtype=params_dtype) + + initialize_affine_weight_gpu( + self.weight_tensor, + init_method, + get_rng_state_tracker, + partition_dim=1 if self.parallel_mode == "row" else 0, + stride=1, + ) + + if self.use_bias: + self.bias_tensor = torch.empty( + self.out_features, device=torch.cuda.current_device(), dtype=params_dtype) + else: + self.bias_tensor = torch.Tensor().to(dtype=params_dtype, + device=torch.cuda.current_device()) - initialize_affine_weight_gpu( - self.weight_tensor, - init_method, - get_rng_state_tracker, - partition_dim=1 if self.parallel_mode == "row" else 0, - stride=1, - ) + with torch.no_grad(): + self.bias_tensor.zero_() - if self.use_bias: - self.bias_tensor = torch.empty( - self.out_features, - device=torch.cuda.current_device(), - dtype=params_dtype) - else: - self.bias_tensor = torch.Tensor().to(dtype=params_dtype, - device=torch.cuda.current_device()) + if parameters_split is None: + parameters_split = ("",) - with torch.no_grad(): - self.bias_tensor.zero_() + assert ( + self.out_features % len(parameters_split) == 0 + ), f"Weight and bias params cannot be split into {len(parameters_split)} parts" - if parameters_split is None: - parameters_split = ("",) + split_size = self.out_features // len(parameters_split) - assert ( - self.out_features % len(parameters_split) == 0 - ), f"Weight and bias params cannot be split into {len(parameters_split)} parts" + self.weight_names = [] + self.bias_names = [] - split_size = self.out_features // len(parameters_split) + for i, pname in enumerate(parameters_split): + wname = pname + "weight" + bname = pname + "bias" - self.weight_names = [] - self.bias_names = [] + self.register_parameter( + wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size]) + ) - for i, pname in enumerate(parameters_split): - wname = pname + "weight" - bname = pname + "bias" + set_tensor_model_parallel_attributes( + tensor=getattr(self, wname), + is_parallel=True, + dim=1 if parallel_mode == "row" else 0, + stride=1, + ) + if self.use_bias: self.register_parameter( - wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size]) + bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) ) + else: + setattr(self, bname, torch.Tensor().to(dtype=params_dtype, + device=torch.cuda.current_device())) - set_tensor_model_parallel_attributes( - tensor=getattr(self, wname), - is_parallel=True, - dim=1 if parallel_mode == "row" else 0, - stride=1, - ) - - 
if self.use_bias: - self.register_parameter( - bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size]) - ) - else: - setattr(self, bname, torch.Tensor().to(dtype=params_dtype, - device=torch.cuda.current_device())) - - if parallel_mode == "column": - set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1) + if parallel_mode == "column": + set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1) - self.weight_names.append(wname) - self.bias_names.append(bname) + self.weight_names.append(wname) + self.bias_names.append(bname) self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features))) @@ -661,24 +669,22 @@ def get_fp8_weights_scratchpad( def forward( self, inp: torch.Tensor, - weight: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, + weight: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. + .. warning:: + + Arguments :attr:`weight` and :attr:`bias` are deprecated and will + be fully removed in future releases. + Parameters ---------- inp : torch.Tensor Input tensor. - weight : torch.Tensor, default = None - An optional weight tensor for the module. This argument is compulsory if module - is initialized with `skip_weight_param_allocation=True` - bias : torch.Tensor, default = None - An optional bias tensor for the module. This argument is compulsory if module - is initialized with `skip_weight_param_allocation=True` and one of `use_bias` - or `return_bias` is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split @@ -694,16 +700,20 @@ def forward( produced) """ + warnings.warn( + "Arguments `weight` and `bias` are deprecated and" + "will be fully removed in future releases.", + category=DeprecationWarning, + ) + with self.prepare_forward(inp, is_first_microbatch) as inp: bias_tensor = ( - bias if bias is not None - else self.bias if self.parameters_split is None + self.bias if self.parameters_split is None else self.bias_tensor if not torch.is_grad_enabled() else self.noop_cat("bias_tensor", self.bias_names) ) weight_tensor = ( - weight if weight is not None - else self.weight if self.parameters_split is None + self.weight if self.parameters_split is None else self.weight_tensor if not torch.is_grad_enabled() else self.noop_cat("weight_tensor", self.weight_names) ) From 134dec04c545afe7ca7bfa176bd72bfa2308128f Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 14 Jul 2023 11:29:46 -0700 Subject: [PATCH 2/3] review comments Signed-off-by: Kirthi Shankar Sivamani --- .../pytorch/module/layernorm_linear.py | 27 ++++++++++--------- transformer_engine/pytorch/module/linear.py | 27 ++++++++++--------- 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 03f96cc075..b8da737e91 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -627,7 +627,7 @@ def __init__( params_dtype: Optional[torch.dtype] = None, parallel_mode: Optional[str] = None, return_layernorm_output: bool = False, - skip_weight_param_allocation: bool = False, # pylint: disable=unused-argument + 
skip_weight_param_allocation: bool = False, parameters_split: Optional[Tuple[str, ...]] = None, zero_centered_gamma: bool = False, ub_bulk_wgrad: bool = False, @@ -636,11 +636,12 @@ def __init__( ) -> None: super().__init__() - warnings.warn( - "Argument `skip_weight_param_allocation` is deprecated and" - "will be fully removed in future releases.", - category=DeprecationWarning, - ) + if skip_weight_param_allocation: + warnings.warn( + "Argument `skip_weight_param_allocation` is deprecated and" + "will be fully removed in future releases.", + category=DeprecationWarning, + ) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features @@ -822,8 +823,8 @@ def get_fp8_weights_scratchpad( def forward( self, inp: torch.Tensor, - weight: Optional[torch.Tensor] = None, # pylint: disable=unused-argument - bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ @@ -853,11 +854,11 @@ def forward( produced) """ - warnings.warn( - "Arguments `weight` and `bias` are deprecated and" - "will be fully removed in future releases.", - category=DeprecationWarning, - ) + if weight is not None or bias is not None: + raise RuntimeError( + "Arguments `weight` and `bias` are deprecated and " + "will be fully removed in future releases." + ) with self.prepare_forward(inp, is_first_microbatch) as inp: bias_tensor = ( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1378ea7e98..3eb97551a0 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -514,18 +514,19 @@ def __init__( return_bias: bool = False, params_dtype: Optional[torch.dtype] = None, parallel_mode: Optional[str] = None, - skip_weight_param_allocation: bool = False, # pylint: disable=unused-argument + skip_weight_param_allocation: bool = False, parameters_split: Optional[Tuple[str, ...]] = None, ub_split_rs: bool = False, ub_split_ag: bool = False, ) -> None: super().__init__() - warnings.warn( - "Argument `skip_weight_param_allocation` is deprecated and" - "will be fully removed in future releases.", - category=DeprecationWarning, - ) + if skip_weight_param_allocation: + warnings.warn( + "Argument `skip_weight_param_allocation` is deprecated and" + "will be fully removed in future releases.", + category=DeprecationWarning, + ) params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.in_features = in_features @@ -669,8 +670,8 @@ def get_fp8_weights_scratchpad( def forward( self, inp: torch.Tensor, - weight: Optional[torch.Tensor] = None, # pylint: disable=unused-argument - bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ @@ -700,11 +701,11 @@ def forward( produced) """ - warnings.warn( - "Arguments `weight` and `bias` are deprecated and" - "will be fully removed in future releases.", - category=DeprecationWarning, - ) + if weight is not None or bias is not None: + raise RuntimeError( + "Arguments `weight` and `bias` are deprecated and " + "will be fully removed in future releases." 
+            )
 
         with self.prepare_forward(inp, is_first_microbatch) as inp:
             bias_tensor = (

From bb6bb4cc780186e25e522bdb92d6f0be56475b9e Mon Sep 17 00:00:00 2001
From: Kirthi Shankar Sivamani 
Date: Fri, 14 Jul 2023 19:09:39 +0000
Subject: [PATCH 3/3] Review

Signed-off-by: Kirthi Shankar Sivamani 
---
 transformer_engine/pytorch/module/layernorm_linear.py | 3 ++-
 transformer_engine/pytorch/module/linear.py           | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index b8da737e91..8d5db1693c 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -639,7 +639,8 @@ def __init__(
         if skip_weight_param_allocation:
             warnings.warn(
                 "Argument `skip_weight_param_allocation` is deprecated and"
-                "will be fully removed in future releases.",
+                "will be fully removed in future releases. It is ignored "
+                "starting from v0.11.",
                 category=DeprecationWarning,
             )
 
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 3eb97551a0..e326db726b 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -524,7 +524,8 @@ def __init__(
         if skip_weight_param_allocation:
             warnings.warn(
                 "Argument `skip_weight_param_allocation` is deprecated and"
-                "will be fully removed in future releases.",
+                "will be fully removed in future releases. It is ignored "
+                "starting from v0.11.",
                 category=DeprecationWarning,
             )
 
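
Usage note (illustrative sketch only, not part of the patches above): the migration
implied by this series, assuming the public `transformer_engine.pytorch.Linear`
import path; tensor shapes and variable names below are arbitrary placeholders.

    import torch
    import transformer_engine.pytorch as te

    # Deprecated pattern removed by this series: externally managed parameters.
    # After these patches, `skip_weight_param_allocation=True` only emits a
    # DeprecationWarning and is otherwise ignored, and passing `weight` or `bias`
    # to forward() raises a RuntimeError.
    #   layer = te.Linear(1024, 1024, skip_weight_param_allocation=True)
    #   out = layer(inp, weight=external_weight, bias=external_bias)

    # Supported pattern: let the module allocate and own its parameters.
    layer = te.Linear(1024, 1024)
    inp = torch.randn(32, 1024, device="cuda")
    out = layer(inp)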