11 changes: 2 additions & 9 deletions include/tvm/topi/nn/rms_norm.h
@@ -41,22 +41,18 @@ using namespace tvm::te;
* \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
* \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
* d_{axis_k} == r_k
* \param bias Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
* d_{axis_k} == r_k
* \param axis The axis to normalize over.
* \param epsilon The epsilon value to avoid division by zero.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
* \return The normalized tensor, with the same shape as data.
*/
inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Tensor& bias,
const Array<Integer>& axis, double epsilon, std::string name = "T_rms_norm",
inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Integer>& axis,
double epsilon, std::string name = "T_rms_norm",
std::string tag = kInjective) {
const auto& data_type = data->dtype;
const auto& weight_type = weight.defined() ? weight->dtype : data_type;
ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
const auto& bias_type = bias.defined() ? bias->dtype : data_type;
ICHECK(data_type == bias_type) << "rms_norm: data and bias must have the same type";

auto square = multiply(data, data);
auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);
@@ -80,9 +76,6 @@ inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Tensor& b
auto output =
data(indices) * weight(reduce_indices) *
tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
if (bias.defined()) {
output += bias(reduce_indices);
}
return output;
};
auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
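
For reference, a minimal NumPy sketch (not part of the patch) of what the updated `topi::nn::rms_norm` now computes once the bias term is dropped; the axis and epsilon defaults below are illustrative:

```python
import numpy as np

def rms_norm_ref(data: np.ndarray, weight: np.ndarray, axis=-1, epsilon=1e-5):
    """Reference: data * weight * rsqrt(mean(data^2, axis) + epsilon)."""
    mean_sq = np.mean(data * data, axis=axis, keepdims=True)  # mean of squares over the normalized axes
    return data * weight / np.sqrt(mean_sq + epsilon)         # no bias term added anymore

x = np.random.rand(2, 8).astype("float32")
w = np.ones(8, dtype="float32")
y = rms_norm_ref(x, w)  # same shape as x
```
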
4 changes: 3 additions & 1 deletion python/tvm/relax/frontend/nn/__init__.py
@@ -15,5 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""A PyTorch-like API to build IRModules."""
from . import spec
from . import op, spec
from .core import Effect, Module, ModuleList, Parameter, Tensor
from .modules import Embedding, IOEffect, KVCache, Linear, RMSNorm
from .op import *
4 changes: 2 additions & 2 deletions python/tvm/relax/frontend/nn/core.py
@@ -218,7 +218,7 @@ def data(self, data: Union[None, NDArray, np.ndarray, "torch.Tensor"]) -> None:

def to(self, dtype: Optional[str] = None) -> None: # pylint: disable=invalid-name
"""Change the dtype of the parameter if it is not bound to any concrete data"""
if dtype is not None and self._data is not None:
if dtype is not None:
if self._data is not None:
raise ValueError(
"Changing the dtype of a Parameter that has been bound to concrete "
@@ -267,7 +267,7 @@ def named_parameters(self, prefix: str = "") -> Iterator[Tuple[str, Parameter]]:

Yields
------
(str, Parameter) Tuple containing the name and parameter
(str, Parameter) - Tuple containing the name and parameter
"""
yield from _attribute_finder(
self, prefix, condition_yield=lambda x: isinstance(x, Parameter)
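
A small sketch (illustrative only; names assumed importable from this revision) of the behavior the reordered check enforces: changing the dtype is fine while a Parameter is unbound, and raises once concrete data is attached:

```python
import numpy as np
from tvm.relax.frontend import nn

p = nn.Parameter((4, 4), "float32")
p.to("float16")                         # OK: no concrete data bound yet
p.data = np.zeros((4, 4), dtype="float16")
# p.to("float32")                       # would now raise ValueError: dtype change on bound data
```
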
279 changes: 277 additions & 2 deletions python/tvm/relax/frontend/nn/modules.py
@@ -15,11 +15,15 @@
# specific language governing permissions and limitations
# under the License.
"""Builtin Modules."""
from typing import List, Optional
from typing import List, Optional, Sequence, Union

from tvm import relax as rx
from tvm import tir
from tvm._ffi import register_func
from tvm.runtime import NDArray

from .core import Effect, Tensor
from . import op
from .core import Effect, Module, Parameter, Tensor, get_default_dtype


class IOEffect(Effect):
@@ -49,3 +53,274 @@ def finalize(self) -> List[rx.Var]:
def print_(self, tensor: Tensor) -> None:
"""Encloses the side effect of NDArray printing"""
raise NotImplementedError


@register_func("effect.print")
def _print(_, array: NDArray) -> None:
print(f"effect.print: shape = {array.shape}, dtype = {array.dtype}, data =\n{array}")


class Linear(Module):
"""
Module for linear layer.
"""

def __init__( # pylint: disable=too-many-arguments
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: Optional[str] = None,
out_dtype: Optional[str] = None,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.out_dtype = out_dtype
self.weight = Parameter((out_features, in_features), dtype)
if bias:
self.bias = Parameter((out_features,), dtype)
else:
self.bias = None

def forward(self, x: Tensor) -> Tensor: # pylint: disable=invalid-name
"""
Forward method for linear layer.

Parameters
----------
x : Tensor
The input tensor.

Returns
-------
ret : Tensor
The output tensor for the linear layer.
"""
# x: [*B, in_features]
# w: [in_features, out_features]
w = op.permute_dims(self.weight) # pylint: disable=invalid-name
# x: [*B, out_features]
x = op.matmul(x, w, out_dtype=self.out_dtype)
if self.bias is not None:
x = x + self.bias
return x
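
A hedged usage sketch for the new `Linear` (shapes and dtypes here are assumptions, not part of the change):

```python
from tvm.relax.frontend import nn

linear = nn.Linear(in_features=64, out_features=128, bias=True, dtype="float16")
# inside a model's forward pass:
#   y = linear.forward(x)   # x: [*B, 64] -> y: [*B, 128]; matmul optionally cast to out_dtype
```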


class RMSNorm(Module):
"""
Module for rms norm layer.
"""

def __init__(
self,
hidden_size: int,
axes: Union[int, List[int]],
epsilon: float = 1e-5,
bias: bool = True,
dtype: Optional[str] = None,
):
super().__init__()
self.epsilon = epsilon
self.axes = axes
self.weight = Parameter((hidden_size,), dtype=dtype)
if bias:
self.bias = Parameter((hidden_size,), dtype=dtype)
else:
self.bias = None

# pylint: disable=invalid-name
def forward(self, x: Tensor):
"""
Forward method for rms norm layer.

Parameters
----------
x : Tensor
The input tensor.

Returns
-------
ret : Tensor
The output tensor for the rms norm layer.
"""
out = op.rms_norm(x, weight=self.weight, axes=self.axes, epsilon=self.epsilon)
if self.bias is not None:
out = op.add(out, self.bias)
return out

# pylint: enable=invalid-name
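
Usage sketch (illustrative values): with the relax op no longer taking a bias, the module keeps the optional bias parameter and adds it after normalization:

```python
from tvm.relax.frontend import nn

norm = nn.RMSNorm(hidden_size=128, axes=-1, epsilon=1e-5, bias=True, dtype="float32")
# inside a model's forward pass:
#   y = norm.forward(x)   # x: [*B, 128] -> rms_norm(x, weight) + bias, same shape
```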


class KVCache(Effect):
"""
Effect to implement KVCache.
"""

init_seq_len: int
unit_shape: List[int]
dtype: str
cache: Optional[rx.Var]

def __init__(
self,
init_seq_len: int,
unit_shape: Sequence[int],
dtype: Optional[str] = None,
):
if dtype is None:
dtype = get_default_dtype()
# Usually the shape is: [init_seq_len, num_heads, head_dim]
# and unit_shape = [num_heads, head_dim]
self.init_seq_len = init_seq_len
self.unit_shape = [int(i) for i in unit_shape]
self.dtype = dtype

def emit_init(self, name_hint: str, bb: rx.BlockBuilder): # pylint: disable=arguments-renamed
"""
Emit the initialization of the KVCache effect.

Parameters
----------
name_hint : str
The name hint of the initialization binding Var.

bb : relax.BlockBuilder
The relax BlockBuilder to emit.
"""
init_shape = rx.ShapeExpr([self.init_seq_len] + self.unit_shape)
return [
bb.emit(
rx.Call(
rx.extern("vm.builtin.attention_kv_cache_create"),
args=[rx.op.zeros(init_shape, self.dtype), init_shape, rx.PrimValue(0)],
sinfo_args=[rx.ObjectStructInfo()],
),
name_hint=name_hint,
)
]

def create(self, name_hint: str) -> List[rx.Var]:
"""
Create the implicit inputs to a relax.Function that represents the KVCache effect.

Parameters
----------
name_hint : str
The name hint of the relax.Var.

Returns
-------
ret : List[rx.Var]
The relax.Var for KVCache, wrapped in a list.
"""
self.cache = rx.Var(name_hint, struct_info=rx.ObjectStructInfo())
return [self.cache]

def finalize(self) -> List[rx.Var]:
"""
Finalize the KVCache effect as the implicit return value of a relax.Function.

Returns
-------
ret : List[rx.Var]
The output relax.Var as KVCache.
"""
result = self.cache
self.cache = None
return [result]

def to(self, dtype: Optional[str] = None) -> None:
"""
Convert the KVCache effect to specific dtype.

Parameters
----------
dtype : Optional[str]
The target data type to convert.
"""
if dtype is not None:
self.dtype = dtype

def view(self, seq_len: tir.Var) -> Tensor:
"""
View the last elements in KVCache.

Parameters
----------
seq_len : tir.Var
The number of last elements to view.

Returns
-------
ret : Tensor
The last tensor to view.
"""
shape = rx.ShapeExpr([seq_len] + self.unit_shape)
return Tensor(
_expr=rx.BlockBuilder.current().emit(
rx.Call(
rx.extern("vm.builtin.attention_kv_cache_view"),
args=[self.cache, shape],
sinfo_args=[rx.TensorStructInfo(shape, self.dtype)],
)
)
)

def append(self, new_element: Tensor) -> None:
"""
Append a new element in KVCache.

Parameters
----------
new_element : Tensor
The new tensor to append.
"""
if new_element.dtype != self.dtype:
raise TypeError(
f'KVCache has been set to use dtype "{self.dtype}", '
f'but got "{new_element.dtype}"'
)
self.cache = rx.BlockBuilder.current().emit(
rx.Call(
rx.extern("vm.builtin.attention_kv_cache_append"),
args=[self.cache, new_element._expr], # pylint: disable=protected-access
sinfo_args=[rx.ObjectStructInfo()],
)
)
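
An assumed usage pattern for `KVCache` (sketch only; `append` and `view` emit relax calls, so they must run under an active `relax.BlockBuilder`, e.g. while tracing a module's forward method, and the cache is typically held as an attribute of an `nn.Module`):

```python
from tvm import tir
from tvm.relax.frontend import nn

cache = nn.KVCache(init_seq_len=2048, unit_shape=[32, 128], dtype="float16")

def attend_keys(k_new: nn.Tensor, total_seq_len: tir.Var) -> nn.Tensor:
    cache.append(k_new)               # k_new: [new_seq_len, 32, 128]; dtype must match the cache
    return cache.view(total_seq_len)  # all cached keys: [total_seq_len, 32, 128]
```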


class Embedding(Module):
"""
Module for embedding layer.
"""

def __init__(self, num: int, dim: int, dtype: Optional[str] = None):
self.num = num
self.dim = dim
self.weight = Parameter((num, dim), dtype=dtype)

def forward(self, x: Tensor): # pylint: disable=invalid-name
"""
Forward method for embedding layer.

Parameters
----------
x : Tensor
The input tensor.

Returns
-------
ret : Tensor
The output tensor for the embedding layer.
"""
if x.ndim == 1:
return op.take(self.weight, x, axis=0)
return op.reshape(
op.take(
self.weight,
op.reshape(x, shape=[-1]),
axis=0,
),
shape=[*x.shape, self.dim], # TODO(@junrushao): revisit and remove self.dim
)
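
Illustrative sketch for `Embedding` (parameter values assumed): index tensors of rank > 1 are flattened, gathered, and reshaped back with the embedding dimension appended:

```python
from tvm.relax.frontend import nn

embed = nn.Embedding(num=32000, dim=4096, dtype="float16")
# inside a model's forward pass:
#   tok_ids: [batch, seq_len] int32  ->  embed.forward(tok_ids): [batch, seq_len, 4096]
```
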
12 changes: 2 additions & 10 deletions python/tvm/relax/frontend/nn/op.py
@@ -488,7 +488,6 @@ def softmax(x: Tensor, axis: int = -1, name: str = "softmax") -> Tensor:
def rms_norm(
x: Tensor,
weight: Tensor,
bias: Optional[Tensor],
axes: Union[int, List[int]],
epsilon: float = 1e-5,
name: str = "rms_norm",
@@ -501,7 +500,7 @@ def rms_norm(

.. math::

out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight + bias
out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight

Parameters
----------
Expand All @@ -511,9 +510,6 @@ def rms_norm(
weight : Tensor
The scale factor.

bias : Tensor
Optional offset factor.

axes : Union[int, List[int]]
The axes that along which the normalization is applied.

@@ -528,11 +524,7 @@
result : Tensor
The computed result.
"""
if bias is None:
bias = _op.zeros(weight.shape, dtype=weight.dtype)
else:
bias = bias._expr
return _wrap_nested(_op.nn.rms_norm(x._expr, weight._expr, bias, axes, epsilon), name)
return _wrap_nested(_op.nn.rms_norm(x._expr, weight._expr, axes, epsilon), name)


def triu(x: Tensor, diagonal: int = 0, name: str = "triu") -> Tensor:
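
A minimal sketch of calling the updated frontend op (bias is now the caller's responsibility, as in `nn.RMSNorm` above; shapes and axes are assumptions):

```python
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op

def normalize(x: nn.Tensor, weight: nn.Tensor) -> nn.Tensor:
    # weight's shape must match the normalized axes, e.g. [hidden_size] for axes=-1
    return op.rms_norm(x, weight, axes=-1, epsilon=1e-5)
```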