/
layer_norm.py
91 lines (74 loc) · 3.11 KB
/
layer_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2022 Apple Inc. All Rights Reserved.
#
from typing import List, Optional, Tuple, Union

import torch
from torch import Size, Tensor, nn

from . import register_norm_fn
@register_norm_fn(name="layer_norm")
class LayerNorm(nn.LayerNorm):
    r"""
    Applies `Layer Normalization <https://arxiv.org/abs/1607.06450>`_ over an input tensor.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine (bool): If ``True``, use learnable affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, *)` where :math:`N` is the batch size
        - Output: same shape as the input
    """

    def __init__(
        self,
        normalized_shape: Union[int, List[int], Size],
        eps: Optional[float] = 1e-5,
        elementwise_affine: Optional[bool] = True,
        *args,
        **kwargs
    ):
        # Extra *args/**kwargs are accepted (and ignored) so the registry can pass a
        # uniform option namespace to every norm layer's constructor.
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
        )

    def profile_module(self, input: Tensor) -> Tuple[Tensor, float, float]:
        """Return (output, #params, MACs). Normalization MACs are treated as
        negligible by this profiler and reported as 0.0; the input passes through
        unchanged so downstream modules can continue profiling."""
        params = sum(p.numel() for p in self.parameters())
        return input, params, 0.0
@register_norm_fn(name="layer_norm_2d")
class LayerNorm2D(nn.GroupNorm):
    """
    Channel-wise layer normalization for 4D inputs, implemented as a single-group
    :class:`~torch.nn.GroupNorm` (GroupNorm with ``num_groups=1`` normalizes over
    all channels and spatial positions, matching LayerNorm over :math:`(C, H, W)`).

    Args:
        num_features (int): :math:`C` from an expected input of size :math:`(N, C, H, W)`
        eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine (bool): If ``True``, use learnable affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)` where :math:`N` is the batch size, :math:`C` is the number of input channels,
          :math:`H` is the input height, and :math:`W` is the input width
        - Output: same shape as the input
    """

    def __init__(
        self,
        num_features: int,
        eps: Optional[float] = 1e-5,
        elementwise_affine: Optional[bool] = True,
        *args,
        **kwargs
    ) -> None:
        # A single group makes GroupNorm equivalent to LayerNorm over (C, H, W).
        super().__init__(
            num_groups=1,
            num_channels=num_features,
            eps=eps,
            affine=elementwise_affine,
        )
        # Kept for __repr__; mirrors the GroupNorm num_channels argument.
        self.num_channels = num_features

    def __repr__(self):
        cls_name = self.__class__.__name__
        return "{}(num_channels={}, eps={}, affine={})".format(
            cls_name, self.num_channels, self.eps, self.affine
        )

    def profile_module(self, input: Tensor) -> (Tensor, float, float):
        # Normalization cost is treated as negligible; only parameters are counted.
        param_count = sum(p.numel() for p in self.parameters())
        return input, param_count, 0.0