Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions src/ntops/kernels/layer_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import functools
import math

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.reduction import arrangement


def application(input, weight, bias, eps, output, num_normalized_elements):
_mean = ntl.zeros(input.dtype.shape, dtype=ntl.float32)

for i in range(input.shape[0]):
_mean += ntl.cast(input[i], ntl.float32)

mean = ntl.sum(_mean, 0) / num_normalized_elements

_var = ntl.zeros(input.dtype.shape, dtype=ntl.float32)

for i in range(input.shape[0]):
diff = ntl.cast(input[i], ntl.float32) - mean
diff = ntl.where(input[i].offsets(-1) < input.source.shape[-1], diff, 0)
_var += diff * diff

var = ntl.sum(_var, 0) / num_normalized_elements

std = ntl.sqrt(var + eps)

for i in range(input.shape[0]):
output[i] = (ntl.cast(input[i], ntl.float32) - mean) / std * weight[i] + bias[i]


def premake(ndim, normalized_shape, dtype=None, block_size=None):
dims = tuple(-(dim + 1) for dim in range(len(normalized_shape)))

arrangement_ = functools.partial(arrangement, dim=dims, block_size=block_size)

tensors = (
Tensor(ndim, other=0, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(0, dtype=dtype),
Tensor(ndim, dtype=dtype),
Tensor(0, dtype=dtype, constexpr=True, value=math.prod(normalized_shape)),
)

return arrangement_, application, tensors
28 changes: 28 additions & 0 deletions src/ntops/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import ntops.kernels.gt
import ntops.kernels.isinf
import ntops.kernels.isnan
import ntops.kernels.layer_norm
import ntops.kernels.le
import ntops.kernels.lt
import ntops.kernels.mm
Expand Down Expand Up @@ -256,6 +257,33 @@ def isnan(input):
return output


def layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-5):
if isinstance(normalized_shape, int):
normalized_shape = (normalized_shape,)

normalized_shape = tuple(normalized_shape)

if weight is None:
weight = torch.ones_like(input)
else:
weight = weight.expand_as(input)

if bias is None:
bias = torch.zeros_like(input)
else:
bias = bias.expand_as(input)

output = torch.empty_like(input)

kernel = _cached_make(
ntops.kernels.layer_norm.premake, input.ndim, normalized_shape
)

kernel(input, weight, bias, eps, output, math.prod(normalized_shape))

return output


def mm(input, mat2, *, out=None):
m, _ = input.shape
_, n = mat2.shape
Expand Down
37 changes: 37 additions & 0 deletions tests/test_layer_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import random

import pytest
import torch

import ntops.torch
from tests.skippers import skip_if_cuda_not_available
from tests.utils import generate_arguments


@skip_if_cuda_not_available
@pytest.mark.parametrize("eps", (1e-8, 1e-5, 1e-3))
@pytest.mark.parametrize("bias_is_none", (False, True))
@pytest.mark.parametrize("weight_is_none", (False, True))
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol, weight_is_none, bias_is_none, eps):
device = "cuda"

input = torch.randn(shape, dtype=dtype, device=device)
normalized_shape = shape[-random.randint(1, len(shape)) :]
if weight_is_none:
weight = None
else:
weight = torch.randn(normalized_shape, dtype=dtype, device=device)
if bias_is_none:
bias = None
else:
bias = torch.randn(normalized_shape, dtype=dtype, device=device)

ninetoothed_output = ntops.torch.layer_norm(
input, normalized_shape, weight=weight, bias=bias, eps=eps
)
reference_output = torch.nn.functional.layer_norm(
input, normalized_shape, weight=weight, bias=bias, eps=eps
)

assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)