In [1]:
import torch
from torch import nn

In [4]:
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是预测模式，直接使用传入的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == '2':
            # 使用使用全连接层的情况下，计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况下，计算通道维上(axis=1)的均值和方差
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)  # 卷积时此处会广播
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta
    return Y, moving_mean.data, moving_var.data

class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims):
        super().__init__()

        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)

        # 参与求梯度和迭代的拉伸和偏移参数，分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 不参与求梯度和迭代的变量，全在内存上初始化成0
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.zeros(shape)
        
    def forward(self, X):
        # 如果X不在内存上，将moving_mean和moving_var复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean,
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y












In [5]:
BN = BatchNorm(num_features=64, num_dims=4)
X = torch.rand((2, 64, 100, 40))
Y = BN(X)
print(Y.shape)

torch.Size([2, 64, 100, 40])


In [11]:
class SSN2d(nn.Module):
    def __init__(self, num_channels, S):
        super().__init__()
        self.S = S
        shape = (1, num_channels*S, 1, 1)
        # 参与求梯度和迭代的拉伸和偏移参数，分别初始化成1和0
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # 不参与求梯度和迭代的变量，全在内存上初始化成0
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.zeros(shape)
    
    def subspectral_normalization(self, X, gamma, beta, S, moving_mean, moving_var, eps, momentum):
        N, C, T, F = X.size()
        X = X.view(N, C*S, T, F//S)
        # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
        if not torch.is_grad_enabled():
            # 如果是预测模式，直接使用传入的均值和方差
            X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
        else:
            # 训练模式下用当前的均值和方差做标准化
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
            X_hat = (X - mean) / torch.sqrt(var + eps)
            # 更新移动平均的均值和方差
            moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
            moving_var = momentum * moving_var + (1.0 - momentum) * var
        Y = (gamma * X_hat + beta).view(N, C, T, F)
        return Y, moving_mean.data, moving_var.data
    def forward(self, X):
        # 如果X不在内存上，将moving_mean和moving_var复制到X所在显存上
        if self.moving_mean.device != X.device:
            self.moving_mean = self.moving_mean.to(X.device)
            self.moving_var = self.moving_var.to(X.device)
        # 保存更新过的moving_mean和moving_var
        Y, self.moving_mean, self.moving_var = self.subspectral_normalization(
            X, self.gamma, self.beta, self.S, self.moving_mean, 
            self.moving_var, eps=1e-5, momentum=0.9)
        return Y

In [15]:
SSN = SSN2d(num_channels=4, S=4)
X = torch.rand((2, 4, 2, 8))
Y = SSN(X)
print(Y)

tensor([[[[ 0.1144, -0.3538,  1.2873, -1.4031,  1.2023, -0.5034, -1.6049,
           -0.1083],
          [ 1.3196, -0.9511, -1.2040, -1.0659, -1.1939, -0.9900, -0.7897,
            1.1286]],

         [[-1.7339,  0.4955,  1.3794,  0.1014, -1.2262, -0.4035,  0.0567,
            0.0411],
          [ 0.4297, -0.9388,  0.2911, -1.2167,  1.2421, -0.8202,  1.2921,
            0.2289]],

         [[-0.9226, -0.9577, -0.4917,  1.5446, -1.4477,  0.9155,  0.5207,
            0.3896],
          [ 0.1566, -0.2950,  1.4262,  0.4455,  0.3263, -1.1356, -0.4894,
            1.2034]],

         [[-1.1183, -0.7862,  1.4285,  1.5689, -0.8271,  1.0257,  0.6956,
           -0.6104],
          [-0.7010,  1.5601,  0.3912, -1.1139, -1.1063,  0.2836, -0.2685,
           -1.4253]]],


        [[[-0.7654, -1.0932,  1.0647,  1.1492,  0.7270, -0.7571,  1.5385,
           -0.4940],
          [-0.2911,  1.5667,  0.1866,  0.4391,  1.2680,  1.0745,  0.3879,
           -0.8854]],

         [[ 0.3204, -0.6707, -1.0126, 