From 3da0418f1de79c7c1f9dc8e68ac354bf38d1ce64 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 25 Jul 2025 14:50:05 +0800 Subject: [PATCH] Add `layer_norm` operator --- src/ntops/kernels/layer_norm.py | 47 +++++++++++++++++++++++++++++++++ src/ntops/torch.py | 28 ++++++++++++++++++++ tests/test_layer_norm.py | 37 ++++++++++++++++++++++++++ 3 files changed, 112 insertions(+) create mode 100644 src/ntops/kernels/layer_norm.py create mode 100644 tests/test_layer_norm.py diff --git a/src/ntops/kernels/layer_norm.py b/src/ntops/kernels/layer_norm.py new file mode 100644 index 0000000..44d1694 --- /dev/null +++ b/src/ntops/kernels/layer_norm.py @@ -0,0 +1,47 @@ +import functools +import math + +import ninetoothed.language as ntl +from ninetoothed import Tensor + +from ntops.kernels.reduction import arrangement + + +def application(input, weight, bias, eps, output, num_normalized_elements): + _mean = ntl.zeros(input.dtype.shape, dtype=ntl.float32) + + for i in range(input.shape[0]): + _mean += ntl.cast(input[i], ntl.float32) + + mean = ntl.sum(_mean, 0) / num_normalized_elements + + _var = ntl.zeros(input.dtype.shape, dtype=ntl.float32) + + for i in range(input.shape[0]): + diff = ntl.cast(input[i], ntl.float32) - mean + diff = ntl.where(input[i].offsets(-1) < input.source.shape[-1], diff, 0) + _var += diff * diff + + var = ntl.sum(_var, 0) / num_normalized_elements + + std = ntl.sqrt(var + eps) + + for i in range(input.shape[0]): + output[i] = (ntl.cast(input[i], ntl.float32) - mean) / std * weight[i] + bias[i] + + +def premake(ndim, normalized_shape, dtype=None, block_size=None): + dims = tuple(-(dim + 1) for dim in range(len(normalized_shape))) + + arrangement_ = functools.partial(arrangement, dim=dims, block_size=block_size) + + tensors = ( + Tensor(ndim, other=0, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(0, dtype=dtype), + Tensor(ndim, dtype=dtype), + Tensor(0, dtype=dtype, constexpr=True, value=math.prod(normalized_shape)), + ) + + return arrangement_, application, tensors diff --git a/src/ntops/torch.py b/src/ntops/torch.py index 2a01960..497d24b 100644 --- a/src/ntops/torch.py +++ b/src/ntops/torch.py @@ -23,6 +23,7 @@ import ntops.kernels.gt import ntops.kernels.isinf import ntops.kernels.isnan +import ntops.kernels.layer_norm import ntops.kernels.le import ntops.kernels.lt import ntops.kernels.mm @@ -256,6 +257,33 @@ def isnan(input): return output +def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5): + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape,) + + normalized_shape = tuple(normalized_shape) + + if weight is None: + weight = torch.ones_like(input) + else: + weight = weight.expand_as(input) + + if bias is None: + bias = torch.zeros_like(input) + else: + bias = bias.expand_as(input) + + output = torch.empty_like(input) + + kernel = _cached_make( + ntops.kernels.layer_norm.premake, input.ndim, normalized_shape + ) + + kernel(input, weight, bias, eps, output, math.prod(normalized_shape)) + + return output + + def mm(input, mat2, *, out=None): m, _ = input.shape _, n = mat2.shape diff --git a/tests/test_layer_norm.py b/tests/test_layer_norm.py new file mode 100644 index 0000000..3433e7d --- /dev/null +++ b/tests/test_layer_norm.py @@ -0,0 +1,37 @@ +import random + +import pytest +import torch + +import ntops.torch +from tests.skippers import skip_if_cuda_not_available +from tests.utils import generate_arguments + + +@skip_if_cuda_not_available +@pytest.mark.parametrize("eps", (1e-8, 1e-5, 1e-3)) +@pytest.mark.parametrize("bias_is_none", (False, True)) +@pytest.mark.parametrize("weight_is_none", (False, True)) +@pytest.mark.parametrize(*generate_arguments()) +def test_cuda(shape, dtype, atol, rtol, weight_is_none, bias_is_none, eps): + device = "cuda" + + input = torch.randn(shape, dtype=dtype, device=device) + normalized_shape = shape[-random.randint(1, len(shape)) :] + if weight_is_none: + weight = None + else: + weight = torch.randn(normalized_shape, dtype=dtype, device=device) + if bias_is_none: + bias = None + else: + bias = torch.randn(normalized_shape, dtype=dtype, device=device) + + ninetoothed_output = ntops.torch.layer_norm( + input, normalized_shape, weight=weight, bias=bias, eps=eps + ) + reference_output = torch.nn.functional.layer_norm( + input, normalized_shape, weight=weight, bias=bias, eps=eps + ) + + assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)