One of the challenges in the study of generative adversarial networks (GANs) is the instability of their training. Spectral Normalization is a weight normalization technique that stabilizes the training of the discriminator. It is computationally light and easy to incorporate into existing implementations. Spectrally normalized GANs (SN-GANs) are capable of generating images of better or equal quality relative to previous training stabilization techniques on the CIFAR-10, STL-10, and ILSVRC2012 datasets. Spectral Normalization has the following properties:
<br><br>(a) The Lipschitz constant is the only hyper-parameter to be tuned, and the algorithm does not require intensive tuning of it for satisfactory performance.
<br>(b) Implementation is simple and the additional computational cost is small.
<br><br> Spectral Normalization leads to a higher Inception Score and a lower (better) FID score.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def l2normalize(v, eps=1e-12):
    """Scale ``v`` to unit Euclidean (L2) norm.

    Args:
        v: Tensor to normalize; the norm is taken over all of its elements.
        eps: Small constant added to the norm so that an all-zero ``v``
            yields zeros instead of NaN from a 0/0 division.

    Returns:
        A tensor with the same shape as ``v``.
    """
    return v / (v.norm() + eps)

In [3]:
# Ref: https://github.com/christiancosgrove/pytorch-spectral-normalization-gan/blob/master/spectral_normalization.py

class SpectralNorm(nn.Module):
    """Wrap a module and apply spectral normalization to its ``weight``.

    On every forward pass one power-iteration step refines the singular-vector
    estimates ``u`` and ``v`` of the weight (flattened to an
    ``(out_channels, -1)`` matrix), the spectral norm is estimated as
    ``sigma = u^T W v``, and the wrapped module is run with ``W / sigma``.

    The wrapped module's registered ``weight`` parameter is never replaced or
    rescaled: it stays the same leaf ``nn.Parameter`` object, so an optimizer
    built from ``parameters()`` keeps updating it, its ``requires_grad`` flag
    is respected, and gradients reach it through both ``W`` and ``sigma``.

    ``u`` and ``v`` are registered on the wrapped module as parameters with
    ``requires_grad=False``.

    Args:
        module: Module exposing a ``weight`` parameter with at least one
            dimension (e.g. ``nn.Conv2d`` or ``nn.Linear``).
    """

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.name = "weight"
        if not self.made_params():
            self.make_params()

    def _update_u_v(self):
        """Run one power-iteration step and return the normalized weight.

        ``u`` and ``v`` stored on the wrapped module are updated in place.

        Returns:
            ``W / sigma``, a non-leaf tensor still connected to the raw
            weight parameter in the autograd graph.
        """
        u = getattr(self.module, "u")
        v = getattr(self.module, "v")
        w = getattr(self.module, self.name)

        # Flatten to (out_channels, everything else); for a Conv2d weight of
        # shape (C_out, C_in, kH, kW) this is C_out x (C_in * kH * kW).
        w_mat = w.reshape(w.shape[0], -1)

        # The power iteration itself is not differentiated.
        with torch.no_grad():
            v_new = l2normalize(torch.mv(w_mat.t(), u))  # (C_in*kH*kW,)
            u_new = l2normalize(torch.mv(w_mat, v_new))  # (C_out,)
            u.copy_(u_new)
            v.copy_(v_new)

        # Use the fresh local estimates (not the stored u/v objects) so a
        # later forward pass cannot alter tensors saved for this backward.
        sigma = torch.dot(u_new, torch.mv(w_mat, v_new))  # u^T W v
        return w / sigma

    def made_params(self):
        """Return True if ``u`` and ``v`` already exist on the wrapped module."""
        return hasattr(self.module, "u") and hasattr(self.module, "v")

    def make_params(self):
        """Create unit-norm random ``u`` and ``v`` and register them."""
        w = getattr(self.module, self.name)

        height = w.shape[0]  # number of output channels
        width = w.reshape(height, -1).shape[1]  # remaining dims flattened

        # new_empty keeps the dtype and device of the weight.
        u = nn.Parameter(
            l2normalize(w.data.new_empty(height).normal_(0, 1)),
            requires_grad=False,
        )
        v = nn.Parameter(
            l2normalize(w.data.new_empty(width).normal_(0, 1)),
            requires_grad=False,
        )

        self.module.register_parameter("u", u)
        self.module.register_parameter("v", v)

    def forward(self, *args):
        """Run the wrapped module using the spectrally normalized weight."""
        module, name = self.module, self.name
        w_sn = self._update_u_v()

        # Shadow the registered parameter with the normalized tensor only for
        # the duration of the call. Instance attributes are found before
        # nn.Module.__getattr__ consults _parameters, so the wrapped module
        # reads W / sigma while the parameter itself is left untouched.
        missing = object()
        previous = module.__dict__.get(name, missing)
        module.__dict__[name] = w_sn
        try:
            return module.forward(*args)
        finally:
            if previous is missing:
                del module.__dict__[name]
            else:
                module.__dict__[name] = previous

In [4]:
# Spectrally normalized 3x3 convolution (5 -> 3 channels) with padding=1.
conv1 = SpectralNorm(
    nn.Conv2d(in_channels=5, out_channels=3, kernel_size=3, padding=1)
)

In [5]:
# Plain (un-normalized) 3x3 convolution for comparison: stride 1, no padding.
conv2 = nn.Conv2d(in_channels=5, out_channels=3, kernel_size=3, stride=1)

In [6]:
# Random input batch: 10 samples, 5 channels, 64x64 spatial size.
inp = torch.randn((10, 5, 64, 64))

In [7]:
# The baseline conv has no padding, so its 3x3 kernel trims 64x64 down to 62x62.
conv2(inp).shape

torch.Size([10, 3, 62, 62])

In [8]:
# Forward pass through the SpectralNorm wrapper: it updates u/v, then runs the wrapped conv.
out = conv1(inp)

In [9]:
# conv1 uses padding=1 with a 3x3 kernel, so the 64x64 spatial size is preserved.
out.shape

torch.Size([10, 3, 64, 64])

In [10]:
# List every parameter registered on the spectrally normalized layer:
# name, shape and element count, with a blank line after each row.
for type_str, model in [('model', conv1)]:
    print(type_str)
    for name_str, param in model.named_parameters():
        shape_str, count = str(param.shape), param.numel()
        print("{:21} {:19} {}".format(name_str, shape_str, count), end="\n\n")

model
module.weight         torch.Size([3, 5, 3, 3]) 135

module.bias           torch.Size([3])     3

module.u              torch.Size([3])     3

module.v              torch.Size([45])    45



In [11]:
# Same listing for the plain conv layer: only weight and bias are registered.
for type_str, model in [('model', conv2)]:
    print(type_str)
    for name_str, param in model.named_parameters():
        shape_str, count = str(param.shape), param.numel()
        print("{:21} {:19} {}".format(name_str, shape_str, count), end="\n\n")

model
weight                torch.Size([3, 5, 3, 3]) 135

bias                  torch.Size([3])     3

