In [4]:

import torch
import torch.nn as nn
 
class AdaptiveLayerNorm(nn.Module):
    """Layer normalization whose scale and shift are predicted from the input.

    The input is normalized over its last dimension; the result is then
    multiplied by ``gamma_net(x)`` and offset by ``beta_net(x)``, where both
    nets are two-layer MLPs (Linear -> ReLU -> Linear) applied to the raw,
    un-normalized input.

    Args:
        normalized_shape (int): Size of the last dimension of the input. It is
            used as both the input and output width of every ``nn.Linear``.
        eps (float): Constant added to the variance before the square root for
            numerical stability.
    """

    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.eps = eps
        # gamma_net is built before beta_net so that parameter registration
        # order (and therefore seeded initialization) matches earlier versions.
        self.gamma_net = self._make_modulation_net(normalized_shape)
        self.beta_net = self._make_modulation_net(normalized_shape)

    @staticmethod
    def _make_modulation_net(width):
        """Build the Linear -> ReLU -> Linear MLP that predicts gamma or beta."""
        return nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
        )

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply the predicted affine.

        Args:
            x (torch.Tensor): Tensor whose last dimension equals the
                ``normalized_shape`` given at construction.

        Returns:
            torch.Tensor: ``gamma_net(x) * normalized(x) + beta_net(x)``, with
            the same shape as ``x``.
        """
        mean = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance with eps inside the square root is the
        # standard layer-norm formulation. The unbiased estimator divides by
        # N - 1 and therefore yields NaN when the last dimension has size 1.
        var = x.var(dim=-1, unbiased=False, keepdim=True)
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)

        # Scale and shift are conditioned on the raw input, not the normalized one.
        gamma = self.gamma_net(x)
        beta = self.beta_net(x)
        return gamma * x_normalized + beta
 
# Example usage: run one random batch through the layer and report its shape.
batch_size, feature_dim = 10, 20  # input is laid out as (batch_size, feature_dim)
x = torch.randn(batch_size, feature_dim)
ada_ln = AdaptiveLayerNorm(feature_dim)
output = ada_ln(x)
print(output.shape)

torch.Size([10, 20])


In [2]:
import torch as T 

# Broadcasting check: subtracting a 3-D tensor from a 4-D tensor.
# x2 has shape (2, 3, 4); against x1's (2, 1, 3, 4) it is aligned on the
# trailing dimensions, i.e. treated as (1, 2, 3, 4).
x1 = T.randn([2, 1, 3, 4])
x2 = T.randn([2, 3, 4])

# x3 spells out that implicit leading axis on a copy of x2.
x3 = x2.clone().unsqueeze(0)

y1 = x1 - x2  # implicit leading axis
y2 = x1 - x3  # explicit leading axis

# Print each difference followed by its shape, y1 first.
for diff in (y1, y2):
    print(diff)
    print(diff.shape)

tensor([[[[ 0.2226,  0.5888, -1.9424, -0.3032],
          [ 1.2601,  0.1420, -0.7135,  2.4239],
          [-1.5754, -0.6671, -0.1235, -0.1898]],

         [[ 0.4960, -0.7030, -0.0987, -0.5353],
          [ 0.0945,  0.9687, -1.8395,  1.2124],
          [-2.1711,  1.1008, -0.2923, -1.1552]]],


        [[[ 2.1922,  0.1262, -1.2223,  1.0639],
          [ 1.7572, -0.9813,  2.1042,  1.9242],
          [-0.2036,  0.4515, -0.1637,  1.0263]],

         [[ 2.4656, -1.1656,  0.6214,  0.8319],
          [ 0.5916, -0.1546,  0.9782,  0.7127],
          [-0.7993,  2.2195, -0.3325,  0.0609]]]])
torch.Size([2, 2, 3, 4])
tensor([[[[ 0.2226,  0.5888, -1.9424, -0.3032],
          [ 1.2601,  0.1420, -0.7135,  2.4239],
          [-1.5754, -0.6671, -0.1235, -0.1898]],

         [[ 0.4960, -0.7030, -0.0987, -0.5353],
          [ 0.0945,  0.9687, -1.8395,  1.2124],
          [-2.1711,  1.1008, -0.2923, -1.1552]]],


        [[[ 2.1922,  0.1262, -1.2223,  1.0639],
          [ 1.7572, -0.9813,  2.1042,  1.9242]