# Instance Normalization
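For generative tasks such as image style transfer, we no longer care about the "common regularities" of the dataset. What matters is the individual pixels, because the information carried by every pixel is important. Batch Normalization (BN), which normalizes over all samples in a batch, is therefore ill-suited: its normalization statistics mix the content of every image in the batch, which washes out the details unique to each sample. By the same reasoning, Layer Normalization (LN), which pools its statistics over all channels of a sample, may ignore the differences between channels.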
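
The batch coupling described above can be observed directly. The minimal sketch below uses PyTorch's built-in `nn.BatchNorm2d` and `nn.InstanceNorm2d` on random data (the shapes and the perturbation are arbitrary choices for illustration). It changes only the second image of a batch and checks whether the normalized output of the first image changes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A batch of two images, shape (B, C, H, W) = (2, 3, 4, 4)
x = torch.randn(2, 3, 4, 4)
x_perturbed = x.clone()
x_perturbed[1] = x_perturbed[1] * 5 + 3  # modify only the second image

bn = nn.BatchNorm2d(3, affine=False)  # training mode: statistics pooled over the batch
inorm = nn.InstanceNorm2d(3)          # statistics per sample and per channel

# Does the output for the FIRST image depend on the second image?
bn_changed = not torch.allclose(bn(x)[0], bn(x_perturbed)[0], atol=1e-5)
in_changed = not torch.allclose(inorm(x)[0], inorm(x_perturbed)[0], atol=1e-5)

print("BN output of image 0 changed:", bn_changed)  # True
print("IN output of image 0 changed:", in_changed)  # False
```

With BN, the first image's output shifts because the batch mean and variance now include the rescaled second image; with IN, each image is normalized using only its own statistics.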

*IN is computed like BN, except that its normalization statistics are taken over all elements of a single channel of a single sample.*
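
Written out for an input $x$ of shape $(B, C, H, W)$ with a per-channel scale $\gamma_c$ and shift $\beta_c$, the IN statistics and output for sample $n$ and channel $c$ are as follows (using the biased variance, as the code later in this notebook does):

$$
\mu_{nc} = \frac{1}{HW}\sum_{h=1}^{H}\sum_{w=1}^{W} x_{nchw}, \qquad
\sigma_{nc}^{2} = \frac{1}{HW}\sum_{h=1}^{H}\sum_{w=1}^{W} \left(x_{nchw} - \mu_{nc}\right)^{2}
$$

$$
y_{nchw} = \gamma_c \, \frac{x_{nchw} - \mu_{nc}}{\sqrt{\sigma_{nc}^{2} + \epsilon}} + \beta_c
$$

BN takes the same sums over $n$, $h$ and $w$ (one $\mu_c$ per channel), while LN takes them over $c$, $h$ and $w$ (one $\mu_n$ per sample).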

![](figures/Norm.png)

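**Although not every model is built on the i.i.d. (independent and identically distributed) assumption, the assumption simplifies the training of conventional machine learning models and improves their predictive ability.**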

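Comparison:
1. BatchNormalization: because each batch has a different distribution, training suffers from Internal Covariate Shift, and BN was introduced to address it. BN computes the mean and variance of each channel over the whole batch, i.e. it normalizes the features of the same channel across different samples. (Internal Covariate Shift, simply put, means that the distribution of a layer's outputs drifts and no longer matches the distribution of its inputs.)
2. InstanceNormalization: works well when the result depends mainly on a single instance, as in generative tasks. It operates on each sample by itself and normalizes the features of each channel separately.
3. LayerNormalization: BN is sensitive to the batch size and impractical in networks with variable-length inputs such as RNNs, so LN is used instead. LN computes the mean and variance of each sample at each layer, normalizing across all channels of the same sample.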
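
The three methods differ only in which axes of a $(B, C, H, W)$ tensor the mean and variance are pooled over. The minimal sketch below (random input, no affine parameters, arbitrary sizes) computes each normalization by hand and compares it with the corresponding PyTorch module.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5)  # (B, C, H, W)
eps = 1e-5


def normalize(x: torch.Tensor, dims: tuple) -> torch.Tensor:
    """Subtract the mean and divide by the biased std, both pooled over `dims`."""
    mean = x.mean(dim=dims, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


# BN: one statistic per channel, pooled over batch and space
bn_manual = normalize(x, (0, 2, 3))
# LN: one statistic per sample, pooled over channels and space
ln_manual = normalize(x, (1, 2, 3))
# IN: one statistic per (sample, channel), pooled over space only
in_manual = normalize(x, (2, 3))

bn = nn.BatchNorm2d(3, affine=False)                    # training mode: batch statistics
ln = nn.LayerNorm([3, 5, 5], elementwise_affine=False)  # normalize over (C, H, W)
inorm = nn.InstanceNorm2d(3)                            # affine=False by default

bn_ok = torch.allclose(bn_manual, bn(x), atol=1e-5)
ln_ok = torch.allclose(ln_manual, ln(x), atol=1e-5)
in_ok = torch.allclose(in_manual, inorm(x), atol=1e-5)
print(bn_ok, ln_ok, in_ok)  # True True True
```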

## Note 
InstanceNorm2d and LayerNorm are very similar but have some subtle differences. InstanceNorm2d is applied to each channel of channeled data such as RGB images, whereas LayerNorm is usually applied to the entire sample, often in NLP tasks. Additionally, LayerNorm applies an elementwise affine transform, while InstanceNorm2d usually doesn't apply an affine transform.
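
One way to see how close the two are: for a 4-D input, `InstanceNorm2d` without affine parameters should give the same result as a `LayerNorm` whose `normalized_shape` covers only `(H, W)`, since both then pool over the spatial positions of each channel of each sample. The sketch below (arbitrary input sizes) checks this and also counts learnable parameters to illustrate the affine difference.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)  # (B, C, H, W)

inorm = nn.InstanceNorm2d(3)                            # affine=False by default
ln_hw = nn.LayerNorm([8, 8], elementwise_affine=False)  # normalize over (H, W) only

# Same statistics (per sample, per channel, over H and W) -> same output
same = torch.allclose(inorm(x), ln_hw(x), atol=1e-5)
print(same)  # True

# Learnable parameters: none for the default InstanceNorm2d, one scale and one
# shift per channel with affine=True, one weight and one bias per normalized
# element for LayerNorm
n_in = sum(p.numel() for p in nn.InstanceNorm2d(3).parameters())
n_in_affine = sum(p.numel() for p in nn.InstanceNorm2d(3, affine=True).parameters())
n_ln = sum(p.numel() for p in nn.LayerNorm([3, 8, 8]).parameters())
print(n_in, n_in_affine, n_ln)  # 0 6 384
```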

```python
# Model optimization: Instance Normalization
import torch as t
import torch.nn as nn
from torch import Tensor


def instance_norm(X: Tensor,
                  gamma: Tensor,
                  beta: Tensor,
                  moving_mean: Tensor,
                  moving_var: Tensor,
                  eps: float,
                  momentum: float) -> tuple[Tensor, Tensor, Tensor]:
    """
    Instance Normalization

    Parameters:
        X: the input, a 4-D tensor of shape (B, C, H, W) whose first dimension is
        the batch size
        gamma: known as "scale" in TensorFlow; the normalized input is multiplied by gamma
        beta: known as "center" in TensorFlow; beta is added after scaling
        moving_mean: running mean of the dataset, shape (1, C, 1, 1)
        moving_var: running variance of the dataset, shape (1, C, 1, 1)
        eps: a value added to the denominator for numerical stability, e.g. 1e-5
        momentum: weight kept from the previous running value when updating
        moving_mean and moving_var, e.g. 0.9
    """
    # Prediction mode (gradient tracking disabled, e.g. under t.no_grad()):
    # normalize directly with the moving averages passed in
    if not t.is_grad_enabled():
        X_hat = (X - moving_mean) / t.sqrt(moving_var + eps)
    else:
        # Images only: B*C*H*W, so the input must be 4-D (no fully connected layers)
        assert X.dim() == 4
        # Per-sample, per-channel statistics over H and W: shape (B, C, 1, 1)
        mean = X.mean(dim=(2, 3), keepdim=True)
        var = ((X - mean) ** 2).mean(dim=(2, 3), keepdim=True)
        # Training mode: normalize with the current instance statistics
        X_hat = (X - mean) / t.sqrt(var + eps)
        # Update the moving averages; averaging over the batch keeps them at (1, C, 1, 1)
        moving_mean = momentum * moving_mean + (1 - momentum) * mean.mean(dim=0, keepdim=True)
        moving_var = momentum * moving_var + (1 - momentum) * var.mean(dim=0, keepdim=True)
    # Scale and shift: multiply by gamma, then add beta
    Y = gamma * X_hat + beta
    # Detach the running statistics so they do not carry autograd history
    return Y, moving_mean.detach(), moving_var.detach()
```
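
For comparison with the moving averages kept by `instance_norm`, this sketch looks at PyTorch's built-in module. `nn.InstanceNorm2d` tracks running statistics only when `track_running_stats=True` (the default is `False`). The buffers then have shape `(C,)`, and PyTorch's convention is `running = (1 - momentum) * running + momentum * new` with `momentum=0.1` by default, so its `momentum` plays the role of `1 - momentum` in the function above. The check that a training step blends in the per-instance means averaged over the batch reflects our understanding of PyTorch's implementation; only `running_mean` and eval-mode behavior are checked.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5)  # (B, C, H, W)

inorm = nn.InstanceNorm2d(3, track_running_stats=True)  # momentum defaults to 0.1
print(inorm.running_mean.shape)  # torch.Size([3]): one value per channel

inorm.train()
_ = inorm(x)  # one training step updates the running statistics

# Per-instance means have shape (B, C); averaging over the batch gives (C,)
batch_avg_mean = x.mean(dim=(2, 3)).mean(dim=0)
expected_mean = (1 - 0.1) * torch.zeros(3) + 0.1 * batch_avg_mean  # running_mean starts at 0
mean_ok = torch.allclose(inorm.running_mean, expected_mean, atol=1e-6)
print(mean_ok)  # True

# In eval mode, the stored running statistics replace the per-instance ones
inorm.eval()
y_eval = inorm(x)
rm = inorm.running_mean.view(1, 3, 1, 1)
rv = inorm.running_var.view(1, 3, 1, 1)
eval_ok = torch.allclose(y_eval, (x - rm) / torch.sqrt(rv + inorm.eps), atol=1e-5)
print(eval_ok)  # True
```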


# Defining the IN module (counterpart of `nn.InstanceNorm2d`)


```python
class InstanceNorm(nn.Module):
    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.9) -> None:
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # Learnable scale and center (gamma, beta), initialized to 1 and 0
        self.gamma = nn.Parameter(t.ones((1, num_features, 1, 1)))
        self.beta = nn.Parameter(t.zeros((1, num_features, 1, 1)))
        # Non-learnable running statistics, initialized to 0 and 1. As buffers they are
        # saved in state_dict() and moved by .to(device) together with the module
        self.register_buffer("moving_mean", t.zeros((1, num_features, 1, 1)))
        self.register_buffer("moving_var", t.ones((1, num_features, 1, 1)))

    def forward(self, X: Tensor) -> Tensor:
        # Store the updated moving_mean and moving_var back into the buffers
        Y, self.moving_mean, self.moving_var = instance_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            self.eps, self.momentum)
        return Y
```
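
`register_buffer` is the standard `nn.Module` mechanism for non-trainable state such as running statistics. The toy module below (`WithBuffer` is a hypothetical name used only for illustration) contrasts a registered buffer with a plain tensor attribute: only the buffer is saved in `state_dict()` and converted by `.to()`.

```python
import torch
import torch.nn as nn


class WithBuffer(nn.Module):
    """Hypothetical toy module: one registered buffer, one plain tensor attribute."""

    def __init__(self, num_features: int) -> None:
        super().__init__()
        # Registered buffer: in state_dict, converted/moved by .to(), not a Parameter
        self.register_buffer("moving_mean", torch.zeros(1, num_features, 1, 1))
        # Plain attribute: ignored by state_dict() and by .to()
        self.plain_mean = torch.zeros(1, num_features, 1, 1)


m = WithBuffer(3)
print(list(m.state_dict().keys()))  # ['moving_mean']
print(len(list(m.parameters())))    # 0: buffers are not updated by the optimizer

# .to(dtype) converts floating-point buffers the same way .to(device) moves them
m = m.to(torch.float64)
print(m.moving_mean.dtype, m.plain_mean.dtype)  # torch.float64 torch.float32
```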

# Checking the results

```python
# Deterministic alternative input:
# X = t.arange(0, 27*2, 1, dtype=t.float32).reshape(2, 3, 3, 3)
X = t.randn((2, 3, 3, 3), requires_grad=False)
X.shape
```

```
torch.Size([2, 3, 3, 3])
```

```python
X
```

```
tensor([[[[-0.2945, -1.5122,  1.0822],
          [ 0.0144, -0.8685, -0.2099],
          [ 1.9219,  0.8777, -1.8266]],

         [[-0.3839,  0.2849, -0.5585],
          [-0.3494, -0.3829,  0.0442],
          [-0.0484,  0.7183,  0.3327]],

         [[ 0.0616, -0.2440, -2.1194],
          [ 0.5784, -0.3306, -0.4595],
          [-0.1795,  1.0353, -1.9222]]],


        [[[ 0.2257, -0.6211, -0.5136],
          [ 1.1761, -0.2628, -0.1389],
          [ 0.2762, -1.4807,  0.3589]],

         [[-0.1493,  0.3094, -0.3921],
          [-0.8907,  0.7273, -0.7800],
          [ 0.2006,  1.0684,  0.8334]],

         [[-0.7727,  0.9414,  0.9452],
          [-1.2443,  1.2667,  0.7676],
          [ 1.0871,  0.6903, -0.6807]]]])
```

```python
in_pytorch = nn.InstanceNorm2d(3)
in_mymethod = InstanceNorm(3)
```

```python
in_pytorch(X)
```

```
tensor([[[[-0.1761, -1.2275,  1.0127],
          [ 0.0907, -0.6717, -0.1030],
          [ 1.7378,  0.8362, -1.4991]],

         [[-0.8684,  0.8112, -1.3068],
          [-0.7817, -0.8659,  0.2068],
          [-0.0259,  1.8995,  0.9312]],

         [[ 0.4704,  0.1574, -1.7627],
          [ 0.9995,  0.0688, -0.0632],
          [ 0.2234,  1.4673, -1.5608]]],


        [[[ 0.4752, -0.7274, -0.5747],
          [ 1.8248, -0.2185, -0.0425],
          [ 0.5469, -1.9480,  0.6643]],

         [[-0.3792,  0.3102, -0.7442],
          [-1.4937,  0.9383, -1.3273],
          [ 0.1467,  1.4511,  1.0979]],

         [[-1.2332,  0.6779,  0.6821],
          [-1.7590,  1.0406,  0.4841],
          [ 0.8403,  0.3979, -1.1306]]]])
```

```python
in_mymethod(X)
```

```
tensor([[[[-0.1761, -1.2275,  1.0127],
          [ 0.0907, -0.6717, -0.1030],
          [ 1.7378,  0.8362, -1.4991]],

         [[-0.8684,  0.8112, -1.3068],
          [-0.7817, -0.8659,  0.2068],
          [-0.0259,  1.8995,  0.9312]],

         [[ 0.4704,  0.1574, -1.7627],
          [ 0.9995,  0.0688, -0.0632],
          [ 0.2234,  1.4673, -1.5608]]],


        [[[ 0.4752, -0.7274, -0.5747],
          [ 1.8248, -0.2185, -0.0425],
          [ 0.5469, -1.9480,  0.6643]],

         [[-0.3792,  0.3102, -0.7442],
          [-1.4937,  0.9383, -1.3273],
          [ 0.1467,  1.4511,  1.0979]],

         [[-1.2332,  0.6779,  0.6821],
          [-1.7590,  1.0406,  0.4841],
          [ 0.8403,  0.3979, -1.1306]]]], grad_fn=<AddBackward0>)
```
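
Comparing printed tensors by eye is error-prone, and `torch.allclose` can check the match programmatically. The sketch below compares a hand-written IN (hypothetical helper `instance_norm_manual`) against `nn.InstanceNorm2d` in both training and evaluation mode. With the default `track_running_stats=False`, PyTorch keeps using per-instance statistics even in `eval()` mode. The custom `InstanceNorm` above instead switches to its moving averages at prediction time, so the two should be expected to differ at inference unless PyTorch's module is built with `track_running_stats=True`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 3, 3, 3)


def instance_norm_manual(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Per-sample, per-channel statistics over H and W (biased variance)
    mean = x.mean(dim=(2, 3), keepdim=True)
    var = ((x - mean) ** 2).mean(dim=(2, 3), keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


in_pytorch = nn.InstanceNorm2d(3)  # track_running_stats=False by default

in_pytorch.train()
train_ok = torch.allclose(in_pytorch(x), instance_norm_manual(x), atol=1e-5)

# Without running statistics, eval mode STILL normalizes with per-instance statistics
in_pytorch.eval()
eval_ok = torch.allclose(in_pytorch(x), instance_norm_manual(x), atol=1e-5)

print(train_ok, eval_ok)  # True True
```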