## Import required modules

In [1]:
import torch
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F

## Basic helper functions

In [2]:
def weights_init(m):
    """
    Initialize the weights of a single module in place (DCGAN-style).

    Designed for ``model.apply(weights_init)``, which calls it on every
    submodule. Dispatch is by class name, checked in this order:

    * name contains ``'Conv'``        -> weight ~ N(0, 0.02), bias = 0
    * name contains ``'BatchNorm2d'`` -> weight ~ N(1, 0.02), bias = 0
    * name contains ``'Linear'``      -> weight ~ N(0, 0.02), bias = 0

    Modules whose name matches but that have no initialized weight tensor are
    skipped instead of raising. Examples: container modules such as
    ``ConvBlock``, norm layers built with ``affine=False`` (weight is None),
    and lazy modules whose parameters are not materialized yet.

    Args:
        m: the ``nn.Module`` to initialize.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        mean = 0.0
    elif 'BatchNorm2d' in classname:
        mean = 1.0
    elif 'Linear' in classname:
        mean = 0.0
    else:
        return

    # getattr with a default: nn.Module raises AttributeError for attributes it
    # does not have, and affine-less norm layers store weight/bias as None.
    weight = getattr(m, 'weight', None)
    if not isinstance(weight, torch.Tensor):
        return
    # Lazy modules hold placeholder parameters that cannot be written to until
    # their first forward pass infers the shape.
    if isinstance(weight, nn.parameter.UninitializedParameter):
        return

    nn.init.normal_(weight, mean, 0.02)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0.0)
            
def get_norm_layer(norm_type='batch'):
    """
    Look up the normalization layer class for a short name.

    Args:
        norm_type: ``'batch'``, ``'instance'`` or ``'layer'``.

    Returns:
        The corresponding layer class (not an instance): ``nn.BatchNorm2d``,
        ``nn.InstanceNorm2d`` or ``nn.LayerNorm``.

    Raises:
        NotImplementedError: if ``norm_type`` is not one of the names above.
    """
    known_layers = (
        ('batch', nn.BatchNorm2d),
        ('instance', nn.InstanceNorm2d),
        ('layer', nn.LayerNorm),
    )
    for name, layer_cls in known_layers:
        if norm_type == name:
            return layer_cls
    raise NotImplementedError('Normalization layer [%s] is not implemented' % norm_type)

def print_network(model):
    """
    Print a module's structure, then its total parameter count.

    Every tensor yielded by ``model.parameters()`` is counted, regardless
    of whether it requires gradients.
    """
    total = sum(p.numel() for p in model.parameters())
    print(model)
    print(f"The number of parameters: {total}")

In [3]:
class AdaptiveInstanceNorm2d(nn.Module):
    """
    Adaptive instance normalization (AdaIN).

    Cited from the "MaskGAN: Towards Diverse and Interactive Facial Image Manipulation"

    The affine parameters are not learned by this module. ``weight`` and
    ``bias`` must be assigned from outside before ``forward`` is called. Because
    the input is folded to ``batch * num_features`` channels before
    normalization, each must hold ``batch * num_features`` values.
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        # Externally supplied affine parameters (see class docstring).
        self.weight = self.bias = None
        # Placeholder statistics required by F.batch_norm's signature. forward()
        # passes repeated copies, so these buffers themselves never change.
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, x):
        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
        batch, channels = x.shape[:2]
        spatial = x.shape[2:]
        # Folding the batch into the channel axis makes batch_norm compute
        # statistics per (sample, channel) pair, i.e. instance normalization.
        folded = x.contiguous().view(1, batch * channels, *spatial)
        normed = F.batch_norm(
            folded,
            self.running_mean.repeat(batch),
            self.running_var.repeat(batch),
            weight=self.weight,
            bias=self.bias,
            training=True,
            momentum=self.momentum,
            eps=self.eps,
        )
        return normed.view(batch, channels, *spatial)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.num_features})'

In [3]:
class ConvBlock(nn.Module):
    """
    Padding -> Conv2d -> optional normalization -> optional activation.

    Args:
        in_channels, out_channels, kernel_size, stride: forwarded to
            ``nn.Conv2d``. The convolution itself uses no padding; padding is
            applied by a separate layer so its type can be chosen.
        padding: padding size given to the padding layer.
        norm: ``'Bn'``, ``'Lbn'`` (lazy BatchNorm), ``'In'``, ``'Lin'`` (lazy
            InstanceNorm), ``'Adain'`` or ``'none'``. Matching is case-insensitive.
        activation: ``'ReLU'``, ``'Sigmoid'``, ``'Tanh'``, ``'Softmax'``,
            ``'LeakyReLU'`` or ``'none'``. Matching is case-insensitive, so the
            default ``'relu'`` selects ReLU.
        padtype: ``'reflection'``, ``'replication'``, ``'zero'`` or
            ``'constant'``. Matching is case-insensitive.
        pad_value: fill value, used only when ``padtype='constant'``.

    Raises:
        AssertionError: for an unrecognized padtype / norm / activation name.
            The exception type is kept for compatibility, but it is raised
            explicitly so the check survives ``python -O``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 norm='none', activation='relu', padtype='zero', pad_value=0.0):
        super(ConvBlock, self).__init__()

        # Padding layer.
        padtype_key = self._option_key(padtype)
        if padtype_key == 'reflection':
            self.pad = nn.ReflectionPad2d(padding)
        elif padtype_key == 'replication':
            self.pad = nn.ReplicationPad2d(padding)
        elif padtype_key == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        elif padtype_key == 'constant':
            # ConstantPad2d requires an explicit fill value.
            self.pad = nn.ConstantPad2d(padding, pad_value)
        else:
            raise AssertionError("Wrong choice of padding type!")

        # Normalization layer.
        norm_key = self._option_key(norm)
        if norm_key == 'bn':
            self.norm = nn.BatchNorm2d(out_channels)
        elif norm_key == 'lbn':
            # Lazy norms infer num_features on the first forward. Their first
            # positional argument is eps, so out_channels must NOT be passed.
            self.norm = nn.LazyBatchNorm2d()
        elif norm_key == 'in':
            self.norm = nn.InstanceNorm2d(out_channels)
        elif norm_key == 'lin':
            self.norm = nn.LazyInstanceNorm2d()
        elif norm_key == 'adain':
            self.norm = AdaptiveInstanceNorm2d(out_channels)
        elif norm_key == 'none':
            self.norm = None
        else:
            raise AssertionError("Wrong choice of Normalization Layer!")

        # Activation layer.
        activation_key = self._option_key(activation)
        if activation_key == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation_key == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation_key == 'tanh':
            self.activation = nn.Tanh()
        elif activation_key == 'softmax':
            self.activation = nn.Softmax(dim=1)
        elif activation_key == 'leakyrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation_key == 'none':
            self.activation = None
        else:
            raise AssertionError("Wrong choice of Activation Layer!")

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=True)

    @staticmethod
    def _option_key(name):
        """Lower-case a string option for case-insensitive matching; pass other values through unchanged."""
        return name.lower() if isinstance(name, str) else name

    def forward(self, x):
        x = self.conv(self.pad(x))
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

In [4]:
class SFT(nn.Module):
    """
    Spatial Feature Transform layer.

    Two small 1x1-conv branches predict a per-pixel scale ``gamma`` and shift
    ``beta`` from a condition map ``y``, and the layer returns
    ``x * gamma + beta``.

    Args:
        channels: channel count of the condition map and of the predicted
            gamma/beta. Defaults to 64, the value previously hard-coded.
        negative_slope: LeakyReLU slope between each branch's two convs
            (default 0.1, as before).
    """
    def __init__(self, channels=64, negative_slope=0.1):
        super(SFT, self).__init__()
        self.negative_slope = negative_slope
        # gamma branch
        self.conv1 = nn.Conv2d(channels, channels, 1)
        self.conv2 = nn.Conv2d(channels, channels, 1)
        # beta branch
        self.conv3 = nn.Conv2d(channels, channels, 1)
        self.conv4 = nn.Conv2d(channels, channels, 1)

    def forward(self, x, y):
        '''
        x is the feature map to modulate.
        y is the condition map; the gamma/beta predicted from it must broadcast against x.
        '''
        # inplace is safe: it only overwrites the fresh conv1/conv3 outputs, never y.
        gamma = self.conv2(F.leaky_relu(self.conv1(y), self.negative_slope, inplace=True))
        beta = self.conv4(F.leaky_relu(self.conv3(y), self.negative_slope, inplace=True))
        return x * gamma + beta

## Implementation of the Spatial Style Encoder

In [None]:
class SpatialStyleEncoder(nn.Module):
    """
    Encode a style image, spatially modulated by a label map, into
    ``num_adain_para`` values (e.g. AdaIN parameters).

    The model runs two parallel conv stacks:

    * Label branch (19-channel input): produces 64-channel condition maps
      at 1/4 and 1/32 of the input resolution.
    * Style branch (3-channel input): produces features at the same two
      resolutions. Each set of features is modulated by an SFT layer using
      the matching condition map.

    Every 4x4 / stride-2 / pad-1 conv halves the spatial size, so inputs whose
    height and width are divisible by 32 keep the two branches aligned. The
    final features are globally average-pooled and projected by a 1x1 conv.

    Args:
        num_adain_para: number of output channels.
    """
    def __init__(self, num_adain_para):
        super(SpatialStyleEncoder, self).__init__()
        # activation='ReLU' is spelled out so it matches ConvBlock's
        # 'ReLU' branch regardless of how ConvBlock treats its default value.
        self.label_conv_1 = ConvBlock(19, 16, 7, 1, 3, activation='ReLU')
        self.label_conv_2 = ConvBlock(16, 32, 4, 2, 1, activation='ReLU')
        self.label_conv_3 = ConvBlock(32, 64, 4, 2, 1, activation='none')
        self.label_conv_4 = ConvBlock(64, 64, 4, 2, 1, activation='ReLU')
        self.label_conv_5 = ConvBlock(64, 64, 4, 2, 1, activation='ReLU')
        # Input is label_conv_5's 64-channel output (was wrongly declared as 32).
        self.label_conv_6 = ConvBlock(64, 64, 4, 2, 1, activation='none')
        self.style_conv_1 = ConvBlock(3, 16, 7, 1, 3, activation='ReLU')
        self.style_conv_2 = ConvBlock(16, 32, 4, 2, 1, activation='ReLU')
        self.style_conv_3 = ConvBlock(32, 64, 4, 2, 1, activation='ReLU')
        self.sft_layer_1 = SFT()
        self.style_conv_4 = ConvBlock(64, 64, 4, 2, 1, activation='ReLU')
        self.style_conv_5 = ConvBlock(64, 64, 4, 2, 1, activation='ReLU')
        self.style_conv_6 = ConvBlock(64, 64, 4, 2, 1, activation='ReLU')
        self.sft_layer_2 = SFT()
        self.average_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_last = nn.Conv2d(64, num_adain_para, 1, 1, 0)

    def forward(self, x, y):
        '''
        x is the style (origin) picture: it feeds the style branch, whose first conv takes 3 channels.
        y is the label mask: it feeds the label branch, whose first conv takes 19 channels
        (presumably a one-hot semantic mask -- TODO confirm with the data loader).

        Returns a tensor of shape (batch, num_adain_para, 1, 1).
        '''
        # Label branch: condition maps at 1/4 and 1/32 resolution.
        cond_1 = self.label_conv_3(self.label_conv_2(self.label_conv_1(y)))
        cond_2 = self.label_conv_6(self.label_conv_5(self.label_conv_4(cond_1)))

        # Style branch, modulated by the matching condition at each scale.
        x = self.style_conv_3(self.style_conv_2(self.style_conv_1(x)))
        x = self.sft_layer_1(x, cond_1)
        x = self.style_conv_6(self.style_conv_5(self.style_conv_4(x)))
        x = self.sft_layer_2(x, cond_2)

        # Apply the pooling layer (previously x was rebound to a new,
        # unapplied nn.AdaptiveAvgPool2d module).
        x = self.average_pool(x)
        return self.conv_last(x)