In [None]:
# | default_exp utils/normalizations

# Imports

In [None]:
# | export


import torch
from torch import nn

# Normalizations

In [None]:
# | export


class DyT(nn.Module):
    # As proposed in Transformers without Normalization: https://arxiv.org/pdf/2503.10622

    def __init__(self, normalized_shape: int | list[int], alpha0: float = 0.5):
        super().__init__()

        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)

        self.normalized_shape = normalized_shape
        self.alpha0 = alpha0

        self.alpha = nn.Parameter(torch.tensor([alpha0]))
        self.weight = nn.Parameter(torch.ones(normalized_shape, dtype=torch.float32))
        self.bias = nn.Parameter(torch.zeros(normalized_shape, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.tanh(self.alpha * x)
        x = x * self.weight + self.bias
        return x

    def extra_repr(self) -> str:
        return f"normalized_shape={self.normalized_shape}, alpha0={self.alpha0}"

In [None]:
# Smoke test: push scaled Gaussian noise through a freshly initialised DyT whose
# parameter shape (20, 10) matches the last two dimensions of the input.
sample_input = 5 * torch.randn(1, 30, 20, 10)
test = DyT((20, 10))
output = test(sample_input)
# Show the module repr, the output shape and the output extremes. At
# initialisation weight is all ones and bias all zeros, so the output is just
# tanh(alpha * x) and its extremes cannot fall outside [-1, 1].
test, output.shape, output.min(), output.max()


(
    DyT(normalized_shape=(20, 10), alpha0=0.5),
    torch.Size([1, 30, 20, 10]),
    tensor(-1., grad_fn=<MinBackward1>),
    tensor(1.0000, grad_fn=<MaxBackward1>)
)

# Architecture

In [None]:
# | export


def get_norm_layer(normalization_name: str, *args, **kwargs):
    if normalization_name == "layernorm":
        norm_layer = nn.LayerNorm(*args, **kwargs)
    elif normalization_name == "batchnorm" or normalization_name == "batchnorm1d":
        norm_layer = nn.BatchNorm1d(*args, **kwargs)
    elif normalization_name == "batchnorm2d":
        norm_layer = nn.BatchNorm2d(*args, **kwargs)
    elif normalization_name == "batchnorm3d":
        norm_layer = nn.BatchNorm3d(*args, **kwargs)
    elif normalization_name == "dyt":
        norm_layer = DyT(*args, **kwargs)
    else:
        raise ValueError(f"Normalization {normalization_name} not implemented")

    return norm_layer

In [None]:
# Look up a layer by name; the extra positional argument (10) is forwarded to nn.LayerNorm.
get_norm_layer("layernorm", 10)

LayerNorm((10,), eps=1e-05, elementwise_affine=True)

# nbdev_export

In [None]:
!nbdev_export