# Pixel to Pixel生成对抗网络

CGAN : https://arxiv.org/abs/1411.1784

Pixel-to-Pixel GAN : https://arxiv.org/abs/1611.07004

In [13]:
import mxnet as mx
import numpy as np

from mxnet import nd
from mxnet import gluon
from mxnet import image
from mxnet import autograd

%matplotlib inline 
import matplotlib as mlt
mlt.rcParams['figure.dpi'] = 120
import matplotlib.pyplot as plt

In [14]:
epochs = 100
batch_size = 10

use_gpu = True
ctx = mx.gpu() if use_gpu else mx.cpu()

lr = 0.0002
beta1 = 0.5
lambda1 = 100

pool_size = 50

## 下载并预处理数据集

In [3]:
dataset = 'facades'

In [4]:
import tarfile

img_wd = 256
img_ht = 256

train_img_path = '%s/train' % (dataset)
val_img_path = '%s/val' % (dataset)

def download_data(dataset):
    if not os.path.exists(dataset):
        url = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/%s.tar.gz' % (dataset)
        os.mkdir(dataset)
        data_file = gluon.utils.download(url)
        with tarfile.open(data_file) as tar:
            tar.extractall(path='.')
        os.remove(data_file)

## ``mx.image.fixed_crop``用法：

```
def fixed_crop(src, x0, y0, w, h, size=None, interp=2):
    """Crop src at fixed location, and (optionally) resize it to size.

    Parameters
    ----------
    src : NDArray
        Input image
    x0 : int
        Left boundary of the cropping area
    y0 : int
        Top boundary of the cropping area
    w : int
        Width of the cropping area
    h : int
        Height of the cropping area
    size : tuple of (w, h)
        Optional, resize to new size after cropping
    interp : int, optional, default=2
        Interpolation method. See resize_short for details.

    Returns
    -------
    NDArray
        An `NDArray` containing the cropped image.
    """
```

**另外由于pixel2pixel GAN支持双向训练，因此我们使用is_reversed来确定是否交换输入和输出的图片。**

In [5]:
def load_data(path, batch_size, is_reversed=False):
    img_in_list = []
    img_out_list = []
    for path, _, fnames in os.walk(path):
        for fname in fnames:
            if not fname.endswith(path):
                continue
            img = os.path.join(path, fname)
            img_arr = image.imread(img).astype(np.float32) / 127.5 - 1
            img_arr = image.imresize(img_arr, img_wd*2, img_ht)
            
            img_arr_in, img_arr_out = [image.fixed_crop(img_arr, 0, 0, img_wd, img_ht), 
                                       image.fixed_crop(img_arr, img_wd, 0, img_wd, img_ht)]
            img_arr_in, img_arr_out = [nd.transpose(img_arr_in, (2, 0, 1)),
                                       nd.transpose(img_arr_out, (2, 0, 1))] # channel X height X width
            img_arr_in, img_arr_out = [img_arr_in.reshape((1,) + img_arr_in.shape), 
                                       img_arr_out.reshape((1,) + img_arr_out.shape)] # batch X channel X height X width
            
            img_in_list.append(img_arr_out if is_reversed else img_arr_in)
            img_out_list.append(img_arr_in if is_reversed else img_arr_out)
            
    return mx.io.NDArrayIter(data=[nd.concat(*img_in_list, dim=0), nd.concat(*img_out_list, dim=0)], 
                             batch_size=batch_size)

# download dataset

# load train data

# load valid data

## Visualization

In [6]:
def visualize(img_arr):
    # recover to the origin
    plt.imshow(((img_arr.asnumpy().transpose(1,2,0) + 1.0) * 127.5).astype(np.uint8))
    plt.axis('off')
    
def preview_train_data():
    img_in_list, img_out_list = train_data.next().data
    for i in range(4):
        plt.subplot(2, 4, i+1)
        visualize(img_in_list[i])
        plt.subplot(2, 4, i+5)
        visualize(img_out_list[i])
    plt.show()
    
## preview_train_data()

## 定义网络

* Generator : U-Net(with skip connection)
* Discriminator : PatchGAN

<img src="http://gluon.mxnet.io/_images/Pixel2pixel-Unet.png" width="500">

### U-Net Generator

In [7]:
class UnetSkipUnit(gluon.nn.Block):
    def __init__(self, inner_channels, outer_channels, inner_block=None, 
                 innermost=False, outermost=False, use_dropout=False, 
                 use_bias=False, **kwargs):
        super().__init__(**kwargs)
        
        with self.name_scope():
            self.outermost = outermost
            # en_conv
            en_conv = gluon.nn.Conv2D(channels    = inner_channels,
                                      in_channels = outer_channels,
                                      kernel_size = 4,
                                      strides     = 2,
                                      padding     = 1,
                                      use_bias    = use_bias)
            
            # en_relu
            en_relu = gluon.nn.LeakyReLU(alpha = 0.2)
            
            # en_norm
            en_norm = gluon.nn.BatchNorm(momentum    = 0.1, 
                                         in_channels = inner_channels)
            
            # de_relu
            de_relu = gluon.nn.Activation(activation = 'relu')
            
            # de_norm
            de_norm = gluon.nn.BatchNorm(momentum    = 0.1, 
                                         in_channels = outer_channels)
            
        if innermost: # if the block is the innest
            de_conv = gluon.nn.Conv2DTranspose(channels    = outer_channels,
                                               in_channels = inner_channels,
                                               kernel_size = 4,
                                               strides     = 2,
                                               padding     = 1,
                                               use_bias    = use_bias)
            encoder = [en_relu, en_conv]
            decoder = [de_relu, de_conv, de_norm]
            model = encoder + decoder
        elif outermost: # if the block is the outest
            de_conv = gluon.nn.Conv2DTranspose(channels    = outer_channels,
                                               in_channels = inner_channels*2,
                                               kernel_size = 4,
                                               strides     = 2,
                                               padding     = 1,
                                               use_bias    = use_bias)
            encoder = [en_conv]
            decoder = [de_relu, de_conv, gluon.nn.Activation('tanh')]
            model = encoder + [inner_block] + decoder
        else: # middle 
            de_conv = gluon.nn.Conv2DTranspose(channels    = outer_channels,
                                               in_channels = inner_channels*2,
                                               kernel_size = 4,
                                               strides     = 2,
                                               padding     = 1,
                                               use_bias    = use_bias)
            encoder = [en_relu, en_conv, en_norm]
            decoder = [de_relu, de_conv, de_norm]
            model = encoder + [inner_block] + decoder
        
        if use_dropout:
            model += [gluon.nn.Dropout(rate=0.5)]
            
        self.unit = gluon.nn.Sequential()
        with self.unit.name_scope():
            for block in model:
                self.unit.add(block) 

    def forward(self, X):
        out = X
        if self.outermost:
            out = self.unit(out)
        else:
            out = nd.concat(self.unit(out), out, dim=1)
        return out 

这里需要注意的是，这样搭建网络的话是说网络的结构是U型的，但是一个输入比如(1, 3, 256, 256)，在过网络feature的时候还是要从outermost也就是最外层的conv走一遍feature，因此可以理解为图片是自底向上过网络抽取feature的。

In [8]:
class UnetGenerator(gluon.nn.Block):
    def __init__(self, in_channels, num_downs, ngf=64, use_dropout=True, **kwargs):
        super().__init__(**kwargs)
            
        # Build Generator Structure
        unet = UnetSkipUnit(inner_channels = ngf * 8, 
                            outer_channels = ngf * 8,
                            innermost      = True)
        for _ in range(num_downs - 5):
            unet = UnetSkipUnit(inner_channels = ngf * 8, 
                                outer_channels = ngf * 8, 
                                inner_block    = unet, 
                                use_dropout    = use_dropout)
            
        unet = UnetSkipUnit(inner_channels = ngf * 8,
                            outer_channels = ngf * 4,
                            inner_block    = unet)
        unet = UnetSkipUnit(inner_channels = ngf * 4,
                            outer_channels = ngf * 2,
                            inner_block    = unet)
        unet = UnetSkipUnit(inner_channels = ngf * 2,
                            outer_channels = ngf * 1,
                            inner_block    = unet)
        unet = UnetSkipUnit(inner_channels = ngf,
                            outer_channels = in_channels,
                            inner_block    = unet,
                            outermost      = True)
        with self.name_scope():
            self.model = unet
            
    def forward(self, X):
        return self.model(X)     

In [9]:
unet = UnetGenerator(in_channels=3, num_downs=8)
unet.initialize()

In [10]:
a = nd.random.normal(shape=(1, 3, 256, 256))
y = unet(a)

In [25]:
y.shape

(1, 3, 256, 256)

In [26]:
unet

UnetGenerator(
  (model): UnetSkipUnit(
    (unit): Sequential(
      (0): Conv2D(3 -> 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): UnetSkipUnit(
        (unit): Sequential(
          (0): LeakyReLU(0.2)
          (1): Conv2D(64 -> 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
          (2): BatchNorm(use_global_stats=False, eps=1e-05, axis=1, fix_gamma=False, momentum=0.1, in_channels=128)
          (3): UnetSkipUnit(
            (unit): Sequential(
              (0): LeakyReLU(0.2)
              (1): Conv2D(128 -> 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
              (2): BatchNorm(use_global_stats=False, eps=1e-05, axis=1, fix_gamma=False, momentum=0.1, in_channels=256)
              (3): UnetSkipUnit(
                (unit): Sequential(
                  (0): LeakyReLU(0.2)
                  (1): Conv2D(256 -> 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
                  (2): 

### PatchGAN Discriminator 

In [268]:
class Discriminator(gluon.nn.HybridBlock):
    def __init__(self, in_channels, ndf=64, n_layers=3, 
                 use_sigmoid=False, use_bias=False, **kwargs):
        super().__init__(self, **kwargs)
        
        with self.name_scope():
            self.model = gluon.nn.HybridSequential()
            kernel_size = 4
            
    

In [270]:
shift_x = np.arange(0, 10)
shift_y = np.arange(0, 10)
shift_x, shift_y = np.meshgrid(shift_x, shift_y)

In [275]:
shifts = np.vstack((shift_x.ravel(), shift_y.ravel(), shift_x.ravel(), shift_y.ravel()))

In [287]:
shifts = shifts.transpose()

In [296]:
anchor = np.array([[0.63310546875, 0.4555439453125, 0.6444779785156249, 0.46747368164062497],
                     [0.6953125, 0.42838525390625, 0.733072900390625, 0.470702978515625]])

In [299]:
anchors = anchor.reshape((1, 2, 4)) + shifts.reshape((1, 100, 4)).transpose((1, 0, 2))

In [308]:
shifts.reshape((1, 100, 4)).transpose((1, 0, 2))

array([[[0, 0, 0, 0]],

       [[1, 0, 1, 0]],

       [[2, 0, 2, 0]],

       [[3, 0, 3, 0]],

       [[4, 0, 4, 0]],

       [[5, 0, 5, 0]],

       [[6, 0, 6, 0]],

       [[7, 0, 7, 0]],

       [[8, 0, 8, 0]],

       [[9, 0, 9, 0]],

       [[0, 1, 0, 1]],

       [[1, 1, 1, 1]],

       [[2, 1, 2, 1]],

       [[3, 1, 3, 1]],

       [[4, 1, 4, 1]],

       [[5, 1, 5, 1]],

       [[6, 1, 6, 1]],

       [[7, 1, 7, 1]],

       [[8, 1, 8, 1]],

       [[9, 1, 9, 1]],

       [[0, 2, 0, 2]],

       [[1, 2, 1, 2]],

       [[2, 2, 2, 2]],

       [[3, 2, 3, 2]],

       [[4, 2, 4, 2]],

       [[5, 2, 5, 2]],

       [[6, 2, 6, 2]],

       [[7, 2, 7, 2]],

       [[8, 2, 8, 2]],

       [[9, 2, 9, 2]],

       [[0, 3, 0, 3]],

       [[1, 3, 1, 3]],

       [[2, 3, 2, 3]],

       [[3, 3, 3, 3]],

       [[4, 3, 4, 3]],

       [[5, 3, 5, 3]],

       [[6, 3, 6, 3]],

       [[7, 3, 7, 3]],

       [[8, 3, 8, 3]],

       [[9, 3, 9, 3]],

       [[0, 4, 0, 4]],

       [[1, 4, 1

In [309]:
anchors

array([[[0.63310547, 0.45554395, 0.64447798, 0.46747368],
        [0.6953125 , 0.42838525, 0.7330729 , 0.47070298]],

       [[1.63310547, 0.45554395, 1.64447798, 0.46747368],
        [1.6953125 , 0.42838525, 1.7330729 , 0.47070298]],

       [[2.63310547, 0.45554395, 2.64447798, 0.46747368],
        [2.6953125 , 0.42838525, 2.7330729 , 0.47070298]],

       [[3.63310547, 0.45554395, 3.64447798, 0.46747368],
        [3.6953125 , 0.42838525, 3.7330729 , 0.47070298]],

       [[4.63310547, 0.45554395, 4.64447798, 0.46747368],
        [4.6953125 , 0.42838525, 4.7330729 , 0.47070298]],

       [[5.63310547, 0.45554395, 5.64447798, 0.46747368],
        [5.6953125 , 0.42838525, 5.7330729 , 0.47070298]],

       [[6.63310547, 0.45554395, 6.64447798, 0.46747368],
        [6.6953125 , 0.42838525, 6.7330729 , 0.47070298]],

       [[7.63310547, 0.45554395, 7.64447798, 0.46747368],
        [7.6953125 , 0.42838525, 7.7330729 , 0.47070298]],

       [[8.63310547, 0.45554395, 8.64447798, 0.46747368]

In [300]:
anchors

array([[[0.63310547, 0.45554395, 0.64447798, 0.46747368],
        [0.6953125 , 0.42838525, 0.7330729 , 0.47070298]],

       [[1.63310547, 0.45554395, 1.64447798, 0.46747368],
        [1.6953125 , 0.42838525, 1.7330729 , 0.47070298]],

       [[2.63310547, 0.45554395, 2.64447798, 0.46747368],
        [2.6953125 , 0.42838525, 2.7330729 , 0.47070298]],

       [[3.63310547, 0.45554395, 3.64447798, 0.46747368],
        [3.6953125 , 0.42838525, 3.7330729 , 0.47070298]],

       [[4.63310547, 0.45554395, 4.64447798, 0.46747368],
        [4.6953125 , 0.42838525, 4.7330729 , 0.47070298]],

       [[5.63310547, 0.45554395, 5.64447798, 0.46747368],
        [5.6953125 , 0.42838525, 5.7330729 , 0.47070298]],

       [[6.63310547, 0.45554395, 6.64447798, 0.46747368],
        [6.6953125 , 0.42838525, 6.7330729 , 0.47070298]],

       [[7.63310547, 0.45554395, 7.64447798, 0.46747368],
        [7.6953125 , 0.42838525, 7.7330729 , 0.47070298]],

       [[8.63310547, 0.45554395, 8.64447798, 0.46747368]

In [295]:
anchors.reshape((100, 4))

array([[0.63310547, 0.45554395, 0.64447798, 0.46747368],
       [1.63310547, 0.45554395, 1.64447798, 0.46747368],
       [2.63310547, 0.45554395, 2.64447798, 0.46747368],
       [3.63310547, 0.45554395, 3.64447798, 0.46747368],
       [4.63310547, 0.45554395, 4.64447798, 0.46747368],
       [5.63310547, 0.45554395, 5.64447798, 0.46747368],
       [6.63310547, 0.45554395, 6.64447798, 0.46747368],
       [7.63310547, 0.45554395, 7.64447798, 0.46747368],
       [8.63310547, 0.45554395, 8.64447798, 0.46747368],
       [9.63310547, 0.45554395, 9.64447798, 0.46747368],
       [0.63310547, 1.45554395, 0.64447798, 1.46747368],
       [1.63310547, 1.45554395, 1.64447798, 1.46747368],
       [2.63310547, 1.45554395, 2.64447798, 1.46747368],
       [3.63310547, 1.45554395, 3.64447798, 1.46747368],
       [4.63310547, 1.45554395, 4.64447798, 1.46747368],
       [5.63310547, 1.45554395, 5.64447798, 1.46747368],
       [6.63310547, 1.45554395, 6.64447798, 1.46747368],
       [7.63310547, 1.45554395,

In [313]:
np.random.choice??