In [1]:
import torch
from torch import nn

In [19]:
class InstanceNorm(nn.Module):

    def __init__(self,
                 num_channels: int = 3,
                 eps: float = 1e-5,
                 apply_affine: bool = True) -> None:
        super().__init__()

        self.num_channels = num_channels
        self.eps = eps
        self.apply_affine = apply_affine

        if self.apply_affine:
            self.scale = nn.Parameter(torch.ones(num_channels))
            self.shift = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x: torch.Tensor):
        """
        x is a tensor of shape [batch_size, channels, *] .
        * denotes any number of (possibly 0) dimensions.
        For example, in an image (2D) convolution this will be [batch_size, channels, height, width]
        """
        
        assert self.num_channels == x.shape[1]

        x_old_shape = x.shape
        x = x.flatten(start_dim=2) # to [batch_size, channels, n]
        var, mean = torch.var_mean(x, dim = [-1], keepdim=True)

        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        print(x_hat.shape)
        #x_hat = x_hat.flatten(start_dim=2)
        
        if self.apply_affine:
            x_hat = self.scale.reshape(1, -1, 1) * x_hat + self.shift.reshape(1, -1, 1)
            
        return x_hat.reshape(x_old_shape)


In [24]:


def print_info(x):
    print(f"""
    shape: {x.shape}
    mean: {x.mean(dim = [2, 3])}
    var: {x.var(dim=[2, 3])}
    """
    )


x = torch.rand([2, 3, 2, 4])
print_info(x)

instance_norm = InstanceNorm(3)
x = instance_norm(x)
print_info(x)



    shape: torch.Size([2, 3, 2, 4])
    mean: tensor([[0.6075, 0.5635, 0.5203],
        [0.3285, 0.4389, 0.4299]])
    var: tensor([[0.0836, 0.0629, 0.0732],
        [0.0558, 0.1078, 0.1039]])
    
torch.Size([2, 3, 8])

    shape: torch.Size([2, 3, 2, 4])
    mean: tensor([[-4.4703e-08, -8.4750e-08, -2.9802e-08],
        [-2.9802e-08,  4.4703e-08, -1.4901e-08]], grad_fn=<MeanBackward1>)
    var: tensor([[0.9999, 0.9998, 0.9999],
        [0.9998, 0.9999, 0.9999]], grad_fn=<VarBackward0>)
    


In [7]:
l

[1, 2, 3]