In [2]:
import torch

In [1]:
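# Layer normalization
# To apply Layer Normalization to a tensor of shape (B, T, C), compute the mean and (biased) variance over the
# last dimension C separately for each (b, t) position, normalize with them, then apply a learnable scale and shift.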

In [3]:
# Fix the RNG seed so "Restart & Run All" reproduces the same example tensor.
torch.manual_seed(0)
x = torch.randn(2, 3, 4)  # (B, T, C) = (2, 3, 4)
x

tensor([[[-0.4960,  1.2352,  1.2042, -0.8860],
         [ 0.8651, -1.2128,  0.0826,  1.1599],
         [-0.2576,  0.0948, -1.2240,  0.8942]],

        [[ 0.3061,  0.1813, -1.4799, -0.5743],
         [-0.5189,  0.2496, -0.5847,  0.9014],
         [ 0.5438, -0.1844,  0.8721,  1.5228]]])

In [4]:
# Per-(b, t) mean over the last dimension C; keepdim leaves shape (B, T, 1) so it broadcasts against x.
mean = torch.mean(x, dim=-1, keepdim=True)
mean

tensor([[[ 0.2644],
         [ 0.2237],
         [-0.1232]],

        [[-0.3917],
         [ 0.0119],
         [ 0.6886]]])

In [5]:
# Biased (population) variance over C: the mean of squared deviations, i.e. divide by C rather than C - 1.
var = (x - mean).pow(2).mean(dim=-1, keepdim=True)
var

tensor([[[0.9318],
         [0.8428],
         [0.5781]],

        [[0.5082],
         [0.3713],
         [0.3781]]])

In [6]:
# Normalize: subtract the per-position mean and divide by the standard deviation;
# eps keeps the denominator away from zero when the variance is tiny.
eps = 1e-5
x_norm = (x - mean).div(torch.sqrt(var + eps))

In [7]:
# Each (b, t) row of x_norm now has mean ~0 and variance ~1 across C.
x_norm

tensor([[[-0.7877,  1.0057,  0.9736, -1.1917],
         [ 0.6986, -1.5647, -0.1537,  1.0198],
         [-0.1769,  0.2866, -1.4478,  1.3381]],

        [[ 0.9789,  0.8037, -1.5265, -0.2561],
         [-0.8709,  0.3902, -0.9790,  1.4597],
         [-0.2354, -1.4196,  0.2984,  1.3566]]])

In [8]:
class LayerNorm(torch.nn.Module):
    """Layer normalization with a learnable per-feature scale and shift.

    Normalizes each input over its trailing feature dimension(s) to zero mean and
    unit (biased) variance, then returns ``gamma * x_norm + beta``.

    Args:
        features: Size of the last dimension to normalize over (``int``), or a
            tuple/list/``torch.Size`` of sizes for the trailing dimensions to
            normalize over jointly (e.g. ``(T, C)``).
        eps: Small constant added to the variance for numerical stability.
    """

    def __init__(self, features, eps=1e-5):
        super().__init__()
        # Accept a single size or a shape; an int means "normalize the last dim only".
        if isinstance(features, (tuple, list)):
            self.normalized_shape = tuple(features)
        else:
            self.normalized_shape = (features,)
        self.gamma = torch.nn.Parameter(torch.ones(self.normalized_shape))   # scale parameter
        self.beta = torch.nn.Parameter(torch.zeros(self.normalized_shape))   # shift parameter
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over its trailing dims and apply the affine transform.

        ``x``'s trailing dimensions are expected to equal ``normalized_shape`` so
        that ``gamma`` and ``beta`` broadcast element-wise. Returns a tensor of the
        same shape as ``x``.
        """
        # Reduce over the last len(normalized_shape) dims, e.g. (-1,) for an int feature size.
        dims = tuple(range(-len(self.normalized_shape), 0))
        mean = x.mean(dim=dims, keepdim=True)
        # Biased (population) variance: divide by the number of normalized elements, not N - 1.
        var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_norm + self.beta

In [9]:
# Affine layer norm over the last dimension (C = 4); gamma starts at ones and beta at zeros.
ln = LayerNorm(features=4)

In [10]:
ln_out = ln(x)
# At init gamma == 1 and beta == 0, so the module must reproduce the manual x_norm computed above,
# and it should also agree with PyTorch's built-in layer norm over the last dimension.
torch.testing.assert_close(ln_out, x_norm)
torch.testing.assert_close(ln_out, torch.nn.functional.layer_norm(x, (x.shape[-1],), eps=ln.eps))
ln_out

tensor([[[-0.7877,  1.0057,  0.9736, -1.1917],
         [ 0.6986, -1.5647, -0.1537,  1.0198],
         [-0.1769,  0.2866, -1.4478,  1.3381]],

        [[ 0.9789,  0.8037, -1.5265, -0.2561],
         [-0.8709,  0.3902, -0.9790,  1.4597],
         [-0.2354, -1.4196,  0.2984,  1.3566]]], grad_fn=<AddBackward0>)

In [4]:
# Print every word containing at least one character that is not a key of `dic`.
dic = {'a': 1, 'b': 2}
vocab = {*dic}
w = ["ab", "cd", "ac", "b", "ba"]
for i in w:
    if not set(i).issubset(vocab):
        print(i)

cd
ac
