In [1]:
%run include.ipynb
from torch.nn import Parameter

class Net_block(object):
    """Namespace of layer factories.

    Every factory takes one ``params`` sequence holding the layer's
    constructor arguments (listed per method) and returns a newly built
    ``nn.Module``. Factories for argument-free layers ignore ``params``.
    """

    # ----- convolutions -------------------------------------------------
    @staticmethod
    def Conv2d(params):
        """params = [in_channels, out_channels, kernel_size, stride, padding, bias]."""
        return nn.Conv2d(in_channels=params[0], out_channels=params[1],
                         kernel_size=params[2], stride=params[3],
                         padding=params[4], bias=params[5])

    @staticmethod
    def Conv3d(params):
        """params = [in_channels(int), out_channels(int), kernel_size(int),
        stride(int), padding(int), bias(bool)]."""
        return nn.Conv3d(in_channels=params[0], out_channels=params[1],
                         kernel_size=params[2], stride=params[3],
                         padding=params[4], bias=params[5])

    @staticmethod
    def SNConv2d(params):
        """Conv2d wrapped in SpectralNorm; params exactly as for ``Conv2d``."""
        return SpectralNorm(Net_block.Conv2d(params))

    @staticmethod
    def ConvT2d(params):
        """params = [in_channels, out_channels, kernel_size, stride, padding, bias]."""
        return nn.ConvTranspose2d(in_channels=params[0], out_channels=params[1],
                                  kernel_size=params[2], stride=params[3],
                                  padding=params[4], bias=params[5])

    # ----- normalization / regularization ------------------------------
    @staticmethod
    def BN1d(params):
        """params = [num_features]."""
        return nn.BatchNorm1d(num_features=params[0])

    @staticmethod
    def BN2d(params):
        """params = [num_features]."""
        return nn.BatchNorm2d(num_features=params[0])

    @staticmethod
    def BN3d(params):
        """params = [num_features]."""
        return nn.BatchNorm3d(num_features=params[0])

    @staticmethod
    def IN2d(params):
        """params = [num_features]."""
        return nn.InstanceNorm2d(num_features=params[0])

    @staticmethod
    def Dropout(params):
        """params = [dropout probability]."""
        return nn.Dropout(p=params[0])

    # ----- activations --------------------------------------------------
    @staticmethod
    def ReLU(params):
        """params = [inplace]."""
        return nn.ReLU(inplace=params[0])

    @staticmethod
    def LeakyReLU(params):
        """params = [negative_slope, inplace]."""
        return nn.LeakyReLU(negative_slope=params[0], inplace=params[1])

    @staticmethod
    def Tanh(params):
        """No arguments; ``params`` is ignored."""
        return nn.Tanh()

    @staticmethod
    def Sigmoid(params):
        """No arguments; ``params`` is ignored."""
        return nn.Sigmoid()

    # ----- pooling / resampling ----------------------------------------
    @staticmethod
    def AvgPool2d(params):
        """params = [kernel_size, stride, padding]."""
        return nn.AvgPool2d(kernel_size=params[0], stride=params[1],
                            padding=params[2])

    @staticmethod
    def AvgPool3d(params):
        """params = [kernel_size(int), stride(int), padding(int)]."""
        return nn.AvgPool3d(kernel_size=params[0], stride=params[1],
                            padding=params[2])

    @staticmethod
    def MaxPool2d(params):
        """params = [kernel_size(int), stride(int), padding(int)]."""
        return nn.MaxPool2d(kernel_size=params[0], stride=params[1],
                            padding=params[2])

    @staticmethod
    def MaxPool3d(params):
        """params = [kernel_size(int), stride(int), padding(int)]."""
        return nn.MaxPool3d(kernel_size=params[0], stride=params[1],
                            padding=params[2])

    @staticmethod
    def Interpolate(params):
        """params = [scale_factor, mode]."""
        return Interpolate_(scale_factor=params[0], mode=params[1])

    # ----- padding ------------------------------------------------------
    @staticmethod
    def RefPad2d(params):
        """params = [padding]."""
        return nn.ReflectionPad2d(padding=params[0])

    @staticmethod
    def RepPad2d(params):
        """params = [padding]."""
        return nn.ReplicationPad2d(padding=params[0])

    # ----- shape helpers ------------------------------------------------
    @staticmethod
    def Identity(params):
        """No arguments; ``params`` is ignored."""
        return Identity_()

    @staticmethod
    def Squeeze(params):
        """No arguments; ``params`` is ignored."""
        return Squeeze_()

    @staticmethod
    def View(params):
        """No arguments; ``params`` is ignored."""
        return View_()

    # ----- composite / dense -------------------------------------------
    @staticmethod
    def ResBlock2d(params):
        """params = [fin, fout, kernel_size, padding_type, norm_type,
        use_dropout, bias, addon_ratio, conv2d_type]."""
        return ResnetBlock(fin=params[0], fout=params[1], kernel_size=params[2],
                           padding_type=params[3], norm_type=params[4],
                           use_dropout=params[5], use_bias=params[6],
                           addon_ratio=params[7], conv2d_type=params[8])

    @staticmethod
    def FC(params):
        """params = [in_features(int), out_features(int), bias(bool)]."""
        return nn.Linear(in_features=params[0], out_features=params[1],
                         bias=params[2])
    
class Block_mapping(object):
    """Registry translating layer-type strings to ``Net_block`` factories.

    ``module_mapping[key](params)`` builds the layer. Note the spelling
    aliases: "Relu"/"LeakyRelu" map to ``Net_block.ReLU``/``LeakyReLU``,
    and "None" maps to ``Net_block.Identity`` (a pass-through layer).
    """

    module_mapping = dict((
        # dense / convolution
        ("FC", Net_block.FC),
        ("Conv2d", Net_block.Conv2d),
        ("Conv3d", Net_block.Conv3d),
        ("SNConv2d", Net_block.SNConv2d),
        ("ConvT2d", Net_block.ConvT2d),
        ("View", Net_block.View),
        # normalization / regularization
        ("BN1d", Net_block.BN1d),
        ("BN2d", Net_block.BN2d),
        ("BN3d", Net_block.BN3d),
        ("IN2d", Net_block.IN2d),
        ("Dropout", Net_block.Dropout),
        # activations
        ("Relu", Net_block.ReLU),
        ("LeakyRelu", Net_block.LeakyReLU),
        ("Tanh", Net_block.Tanh),
        ("Sigmoid", Net_block.Sigmoid),
        # pooling / resampling / padding
        ("AvgPool2d", Net_block.AvgPool2d),
        ("AvgPool3d", Net_block.AvgPool3d),
        ("MaxPool2d", Net_block.MaxPool2d),
        ("MaxPool3d", Net_block.MaxPool3d),
        ("Interpolate", Net_block.Interpolate),
        ("RefPad2d", Net_block.RefPad2d),
        ("RepPad2d", Net_block.RepPad2d),
        # shape helpers / composite
        ("Squeeze", Net_block.Squeeze),
        ("ResBlock2d", Net_block.ResBlock2d),
        ("None", Net_block.Identity),
    ))

class Squeeze_(nn.Module):
    """Remove size-1 dimensions from the input tensor.

    Args:
        dim: optional single dimension to squeeze. ``None`` (the default)
            removes *every* size-1 dimension -- including a batch dimension
            of size 1 -- so pass an explicit ``dim`` when the batch axis
            must survive a batch of one.
    """

    def __init__(self, dim=None):
        super(Squeeze_, self).__init__()
        self.dim = dim

    def forward(self, x):
        # getattr: instances pickled before ``dim`` existed have no such attribute.
        dim = getattr(self, "dim", None)
        if dim is None:
            return torch.squeeze(x)
        return torch.squeeze(x, dim)
    
class Identity_(nn.Module):
    """Pass-through layer: ``forward`` hands back the very object it received.

    ``Block_mapping`` registers it (via ``Net_block.Identity``) under the
    key "None", i.e. "no layer here".
    """

    def forward(self, x):
        # Intentionally no computation.
        return x
    
class View_(nn.Module):
    """Flatten every dimension after the first (batch) dimension.

    ``(N, d1, d2, ...) -> (N, d1 * d2 * ...)``. The trailing dimensions are
    collapsed in their current order, whatever their layout.
    """

    def forward(self, x):
        # reshape rather than view: view() raises on non-contiguous tensors
        # (e.g. after transpose/permute); reshape returns a view when it can
        # and copies only when it has to.
        return x.reshape(x.size(0), -1)
    
class Interpolate_(nn.Module):
    """``nn.functional.interpolate`` packaged as a module (usable in ``nn.Sequential``).

    Args:
        scale_factor: spatial multiplier passed to ``interpolate``.
        mode: interpolation mode passed to ``interpolate`` (e.g. 'nearest').
        align_corners: optional, passed through unchanged; the default
            ``None`` is also ``interpolate``'s own default.
    """

    def __init__(self, scale_factor, mode, align_corners=None):
        super(Interpolate_, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        return nn.functional.interpolate(
            x, scale_factor=self.scale_factor, mode=self.mode,
            # getattr: instances pickled before align_corners existed lack it.
            align_corners=getattr(self, "align_corners", None))

    def extra_repr(self):
        # Shown inside repr(module) / print(model).
        text = 'scale_factor={}, mode={}'.format(self.scale_factor, self.mode)
        align_corners = getattr(self, "align_corners", None)
        if align_corners is not None:
            text += ', align_corners={}'.format(align_corners)
        return text

# ResnetBlock changes only C (fin -> fout), NOT H or W (for odd kernel sizes).
class ResnetBlock(nn.Module):
    """Residual block computing ``x_block(x) + conv_block(x) * addon_ratio``.

    Layers are created through ``Block_mapping.module_mapping``:

    * ``conv_block``: pad -> conv(fin->fout) -> norm -> ReLU [-> Dropout(0.5)]
      -> pad -> conv(fout->fout) -> norm
    * ``x_block`` (shortcut): identity when ``fin == fout``; otherwise
      pad -> conv(fin->fout, no bias) -> norm so the channel counts match.

    Every convolution has stride 1 and padding ``(kernel_size - 1) // 2``
    (1 for non-int kernel sizes), so H and W are preserved for odd kernels.
    """

    def __init__(self, fin, fout, kernel_size, padding_type, norm_type, use_dropout, use_bias, addon_ratio, conv2d_type):
        super(ResnetBlock, self).__init__()
        # Scale applied to the residual branch before adding it to the shortcut.
        self.addon_ratio = addon_ratio
        self.conv_block, self.x_block = self.build_conv_block(fin, fout, kernel_size, padding_type,
                                                              norm_type, use_dropout, use_bias, conv2d_type)

    @staticmethod
    def _padding(padding_type, pad):
        """Return ``(explicit padding layers, conv padding)`` for one convolution."""
        if padding_type == 'reflect':
            return [Block_mapping.module_mapping["RefPad2d"]([pad])], 0
        if padding_type == 'replicate':
            return [Block_mapping.module_mapping["RepPad2d"]([pad])], 0
        if padding_type == 'zero':
            return [], pad
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, fin, fout, kernel_size, padding_type, norm_type, use_dropout, use_bias, conv2d_type):
        """Construct the residual and shortcut branches.

        Parameters:
            fin (int)           -- the number of channels in the input
            fout (int)          -- the number of channels in the output
            kernel_size         -- convolution kernel size (odd int keeps H/W)
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_type (str)     -- Block_mapping key of the norm layer, e.g. "BN2d", "IN2d", "None"
            use_dropout (bool)  -- if True, add Dropout(0.5) after the first ReLU
            use_bias (bool)     -- if the residual-branch conv layers use bias
            conv2d_type (str)   -- "Conv2d" or "SNConv2d"
        Returns:
            (conv_block, x_block) -- two nn.Sequential: residual branch and shortcut
        Raises:
            NotImplementedError for an unknown padding_type.
        """
        assert conv2d_type in ("Conv2d", "SNConv2d"), \
            'conv2d_type must be "Conv2d" or "SNConv2d", got %r' % (conv2d_type,)
        layers = Block_mapping.module_mapping
        # "Same" padding for stride-1 convs. The original hard-coded 1, which
        # is only right for kernel_size == 3 (still the value produced here);
        # non-int kernel sizes keep that original padding of 1.
        pad = (kernel_size - 1) // 2 if isinstance(kernel_size, int) else 1

        # Shortcut branch, built first to keep the original module-creation
        # (and therefore RNG / weight-initialization) order. Padding belongs
        # only in front of the projection conv: the original also padded the
        # identity shortcut, making it 2 pixels larger than the residual
        # branch for reflect/replicate padding and crashing forward().
        if fin != fout:
            pad_layers, conv_pad = self._padding(padding_type, pad)
            x_block = pad_layers + [
                layers[conv2d_type]([fin, fout, kernel_size, 1, conv_pad, False]),
                layers[norm_type]([fout])]
        else:
            x_block = [layers["None"]([])]

        # Residual branch: conv(fin->fout) -> norm -> ReLU [-> dropout].
        pad_layers, conv_pad = self._padding(padding_type, pad)
        conv_block = pad_layers + [
            layers[conv2d_type]([fin, fout, kernel_size, 1, conv_pad, use_bias]),
            layers[norm_type]([fout]),
            layers["Relu"]([True])]
        if use_dropout:
            conv_block += [layers["Dropout"]([0.5])]

        # Second conv(fout->fout) -> norm; no activation before the residual sum.
        pad_layers, conv_pad = self._padding(padding_type, pad)
        conv_block += pad_layers + [
            layers[conv2d_type]([fout, fout, kernel_size, 1, conv_pad, use_bias]),
            layers[norm_type]([fout])]

        return nn.Sequential(*conv_block), nn.Sequential(*x_block)

    def forward(self, x):
        """Return ``x_block(x) + conv_block(x) * addon_ratio``."""
        out = self.x_block(x) + self.conv_block(x) * self.addon_ratio
        return out
    
# ===== Definitions for Spectral Normalization
def l2normalize(v, eps=1e-12):
    """Scale ``v`` to unit L2 norm; ``eps`` keeps an all-zero input finite."""
    denominator = v.norm() + eps
    return v / denominator

class SpectralNorm(nn.Module):
    """Wrap ``module`` so its ``name`` parameter is divided by an estimate of
    its largest singular value (power iteration) before every forward pass.

    On construction the module's ``<name>`` parameter is replaced by
    ``<name>_bar`` (trainable, unnormalized weight) and ``<name>_u`` /
    ``<name>_v`` (non-trainable unit vectors for the power iteration).

    Args:
        module: the layer to wrap.
        name: name of the parameter to normalize (default 'weight').
        power_iterations: power-iteration steps per forward pass.

    NOTE(review): ``module.<name>`` does not exist until the first forward.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Advance the power iteration and set ``module.<name> = <name>_bar / sigma``."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        # Weight viewed as a (height, everything-else) matrix; loop-invariant.
        w_mat = w.view(height, -1)
        for _ in range(self.power_iterations):
            # Operate on .data so the iteration itself stays out of autograd.
            v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data))
            u.data = l2normalize(torch.mv(w_mat.data, v.data))

        # sigma ~= u^T W v; gradients flow into w, not into u or v.
        sigma = u.dot(w_mat.mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True if the ``_u``, ``_v`` and ``_bar`` parameters already exist."""
        return all(hasattr(self.module, self.name + suffix)
                   for suffix in ("_u", "_v", "_bar"))

    def _make_params(self):
        """Replace ``module.<name>`` with ``<name>_bar`` plus random unit vectors ``_u``/``_v``."""
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Remove the original parameter so _update_u_v can assign a plain
        # (non-Parameter) tensor to module.<name>.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args, **kwargs):
        """Refresh the normalized weight, then run the wrapped module's forward."""
        self._update_u_v()
        return self.module.forward(*args, **kwargs)