142 changes: 77 additions & 65 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -4,6 +4,7 @@
 
 """LayerNormLinear API"""
 import os
+import warnings
 from typing import Union, Optional, Callable, Tuple, List, Dict, Any
 
 
@@ -538,6 +539,11 @@ class LayerNormLinear(TransformerEngineBaseModule):
     r"""
     Applies layer normalization followed by linear transformation to the incoming data.
 
+    .. warning::
+
+        Argument :attr:`skip_weight_param_allocation` is deprecated and will
+        be fully removed in future releases.
+
     Parameters
     ----------
     in_features : int
@@ -585,9 +591,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
                    used to decide whether this Linear layer is Column Parallel Linear or Row
                    Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
                    When set to `None`, no communication is performed.
-    skip_weight_param_allocation: bool, default = `False`
-                                 if set to `True`, weight parameter is not allocated and must be
-                                 passed as a keyword argument `weight` during the forward pass.
 
     Optimization parameters
     -----------------------
@@ -633,6 +636,14 @@ def __init__(
     ) -> None:
         super().__init__()
 
+        if skip_weight_param_allocation:
+            warnings.warn(
+                "Argument `skip_weight_param_allocation` is deprecated and "
+                "will be fully removed in future releases. It is ignored "
+                "starting from v0.11.",
+                category=DeprecationWarning,
+            )
+
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.in_features = in_features
         self.out_features = out_features
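
A minimal usage sketch of the deprecation path added above (not part of the diff), assuming a CUDA-capable install of Transformer Engine v0.11 or later; the feature sizes are illustrative:

import warnings

import transformer_engine.pytorch as te

# Passing the deprecated flag no longer skips weight allocation; it only
# emits a DeprecationWarning via the warnings.warn call added above.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    layer = te.LayerNormLinear(1024, 1024, skip_weight_param_allocation=True)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
# The weight parameter is now always allocated, regardless of the flag.
print(layer.weight.shape)  # torch.Size([1024, 1024])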
@@ -695,72 +706,71 @@ def __init__(
                 setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel)
             self.reset_layer_norm_parameters()
 
-        if not skip_weight_param_allocation:
-            self.weight_tensor = torch.empty(
-                self.out_features, self.in_features,
-                device=torch.cuda.current_device(),
-                dtype=params_dtype)
-
-            initialize_affine_weight_gpu(
-                self.weight_tensor,
-                init_method,
-                get_rng_state_tracker,
-                partition_dim=1 if self.parallel_mode == "row" else 0,
-                stride=1,
-            )
-
-            if self.use_bias:
-                self.bias_tensor = torch.empty(
-                    self.out_features,
-                    device=torch.cuda.current_device(),
-                    dtype=params_dtype)
-            else:
-                self.bias_tensor = torch.Tensor().to(dtype=params_dtype,
-                                                     device=torch.cuda.current_device())
-
-            with torch.no_grad():
-                self.bias_tensor.zero_()
-
-            if parameters_split is None:
-                parameters_split = ("",)
-
-            assert (
-                self.out_features % len(parameters_split) == 0
-            ), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
-
-            split_size = self.out_features // len(parameters_split)
-
-            self.weight_names = []
-            self.bias_names = []
-
-            for i, pname in enumerate(parameters_split):
-                wname = pname + "weight"
-                bname = pname + "bias"
-
-                self.register_parameter(
-                    wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
-                )
-
-                set_tensor_model_parallel_attributes(
-                    tensor=getattr(self, wname),
-                    is_parallel=True,
-                    dim=1 if parallel_mode == "row" else 0,
-                    stride=1,
-                )
-
-                if self.use_bias:
-                    self.register_parameter(
-                        bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
-                    )
-                else:
-                    setattr(self, bname, torch.Tensor().to(dtype=params_dtype,
-                                                           device=torch.cuda.current_device()))
-
-                if parallel_mode == "column":
-                    set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
-
-                self.weight_names.append(wname)
-                self.bias_names.append(bname)
+        self.weight_tensor = torch.empty(
+            self.out_features, self.in_features,
+            device=torch.cuda.current_device(),
+            dtype=params_dtype)
+
+        initialize_affine_weight_gpu(
+            self.weight_tensor,
+            init_method,
+            get_rng_state_tracker,
+            partition_dim=1 if self.parallel_mode == "row" else 0,
+            stride=1,
+        )
+
+        if self.use_bias:
+            self.bias_tensor = torch.empty(
+                self.out_features,
+                device=torch.cuda.current_device(),
+                dtype=params_dtype)
+        else:
+            self.bias_tensor = torch.Tensor().to(dtype=params_dtype,
+                                                 device=torch.cuda.current_device())
+
+        with torch.no_grad():
+            self.bias_tensor.zero_()
+
+        if parameters_split is None:
+            parameters_split = ("",)
+
+        assert (
+            self.out_features % len(parameters_split) == 0
+        ), f"Weight and bias params cannot be split into {len(parameters_split)} parts"
+
+        split_size = self.out_features // len(parameters_split)
+
+        self.weight_names = []
+        self.bias_names = []
+
+        for i, pname in enumerate(parameters_split):
+            wname = pname + "weight"
+            bname = pname + "bias"
+
+            self.register_parameter(
+                wname, Parameter(self.weight_tensor[i * split_size : (i+1) * split_size])
+            )
+
+            set_tensor_model_parallel_attributes(
+                tensor=getattr(self, wname),
+                is_parallel=True,
+                dim=1 if parallel_mode == "row" else 0,
+                stride=1,
+            )
+
+            if self.use_bias:
+                self.register_parameter(
+                    bname, Parameter(self.bias_tensor[i * split_size : (i+1) * split_size])
+                )
+            else:
+                setattr(self, bname, torch.Tensor().to(dtype=params_dtype,
+                                                       device=torch.cuda.current_device()))
+
+            if parallel_mode == "column":
+                set_tensor_model_parallel_attributes(getattr(self, bname), True, 0, 1)
+
+            self.weight_names.append(wname)
+            self.bias_names.append(bname)
 
         self.fp8_weight_shapes.append(torch.Size((self.out_features, self.in_features)))
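For reference, a small pure-PyTorch sketch (illustrative, not from the PR) of how the `parameters_split` loop above carves the fused tensors: each prefix becomes a separately named parameter that views a contiguous row-slice of the same storage, e.g. a fused QKV projection exposing `query_weight`/`key_weight`/`value_weight`. The prefixes and sizes below are hypothetical.

import torch

out_features, in_features = 12, 4
parameters_split = ("query_", "key_", "value_")  # hypothetical prefixes

weight_tensor = torch.empty(out_features, in_features)
split_size = out_features // len(parameters_split)  # 4 rows per split

weight_names = []
for i, pname in enumerate(parameters_split):
    wname = pname + "weight"
    # Slicing yields a view: each named parameter shares storage with the
    # fused tensor, so no copies are made and updates stay in sync.
    piece = weight_tensor[i * split_size : (i + 1) * split_size]
    assert piece.data_ptr() == weight_tensor[i * split_size].data_ptr()
    weight_names.append(wname)

print(weight_names)  # ['query_weight', 'key_weight', 'value_weight']
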
@@ -821,17 +831,15 @@ def forward(
         """
         Apply layer normalization to the input followed by a linear transformation.
 
+        .. warning::
+
+            Arguments :attr:`weight` and :attr:`bias` are deprecated and will
+            be fully removed in future releases.
+
         Parameters
         ----------
         inp : torch.Tensor
              Input tensor.
-        weight : torch.Tensor, default = None
-               An optional weight tensor for the module. This argument is compulsory if module
-               is initialized with `skip_weight_param_allocation=True`
-        bias : torch.Tensor, default = None
-              An optional bias tensor for the module. This argument is compulsory if module
-              is initialized with `skip_weight_param_allocation=True` and one of `use_bias`
-              or `return_bias`
         is_first_microbatch : {True, False, None}, default = None
                              During training using either gradient accumulation or
                              pipeline parallelism a minibatch of data is further split
@@ -847,16 +855,20 @@ def forward(
                              produced)
         """
 
+        if weight is not None or bias is not None:
+            raise RuntimeError(
+                "Arguments `weight` and `bias` are deprecated and "
+                "will be fully removed in future releases."
+            )
+
         with self.prepare_forward(inp, is_first_microbatch) as inp:
             bias_tensor = (
-                bias if bias is not None
-                else self.bias if self.parameters_split is None
+                self.bias if self.parameters_split is None
                 else self.bias_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("bias_tensor", self.bias_names)
             )
             weight_tensor = (
-                weight if weight is not None
-                else self.weight if self.parameters_split is None
+                self.weight if self.parameters_split is None
                 else self.weight_tensor if not torch.is_grad_enabled()
                 else self.noop_cat("weight_tensor", self.weight_names)
             )
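
A hedged sketch (not from the PR) of the forward-pass behavior after this change, assuming a CUDA device is available: the regular path is unchanged, while the deprecated `weight`/`bias` arguments now fail fast with the RuntimeError added above.

import torch
import transformer_engine.pytorch as te

layer = te.LayerNormLinear(1024, 1024)
inp = torch.randn(8, 1024, device="cuda")

out = layer(inp)  # normal path: module-owned parameters are used

try:
    # Deprecated by this PR: overriding parameters at forward time.
    layer(inp, weight=torch.empty_like(layer.weight))
except RuntimeError as err:
    print(err)  # Arguments `weight` and `bias` are deprecated and will be fully removed in future releases.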