## pix2pix(U-Net + GAN) experiments

In [246]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

### Model definitions

In [247]:
from keras import objectives
from keras import backend as K
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.layers import Input
from keras.layers.merge import concatenate
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import Conv2D, Conv2DTranspose, Cropping2D
from keras.layers.core import Activation, Dropout

In [248]:
# U-Net Generator
def g_unet(nf, model_name, in_ch=1, out_ch=1, batch_size=1, alpha=0.2):
    ''' параметры:
    input shape = (100, 100, in_ch)
    output = (100, 250, out_ch)
    nf - число фильтров на входном слое
    alpha - параметр LeakyReLU
    '''
    i = Input(shape=(100, 100, in_ch))
    # (100, 100, in_ch)
    
    conv1 = Conv2D(nf, (6, 6), padding='same', strides=(5, 5))(i)
    conv1 = BatchNormalization(axis=3)(conv1)
    x = LeakyReLU(alpha)(conv1)
    # (20, 20, nf)
    
    conv2 = Conv2D(nf*5, (6, 6), padding='same', strides=(5, 5))(x)
    conv2 = BatchNormalization(axis=3)(conv2)
    x = LeakyReLU(alpha)(conv2)
    # (4, 4, nf*5)
    
    conv3 = Conv2D(nf*10, (3, 3), padding='same', strides=(2, 2))(x)
    conv3 = BatchNormalization(axis=3)(conv3)
    x = LeakyReLU(alpha)(conv3)
    # (2, 2, nf*10)

    conv4 = Conv2D(nf*10, (2, 2), padding='valid', strides=(1, 1))(x)
    conv4 = BatchNormalization(axis=3)(conv4)
    x = LeakyReLU(alpha)(conv4)
    # (1, 1, nf*10)

    dconv1 = Conv2DTranspose(nf*10, (2, 2), strides=(1, 1))(x)
    dconv1 = BatchNormalization(axis=3)(dconv1)
    dconv1 = Dropout(0.5)(dconv1)
    x = concatenate([dconv1, conv3], axis=3)
    x = LeakyReLU(alpha)(x)
    # (2, 2, nf*(10 + 10))

    dconv2 = Conv2DTranspose(nf*5, (2, 2), strides=(2, 2))(x)
    dconv2 = BatchNormalization(axis=3)(dconv2)
    x = concatenate([dconv2, conv2], axis=3)
    x = LeakyReLU(alpha)(x)
    # (4, 4, nf*(5 + 5))

    dconv3 = Conv2DTranspose(nf, (2, 2), strides=(5, 5))(x)
    dconv3 = BatchNormalization(axis=3)(dconv3)
    x = concatenate([dconv3, conv1], axis=3)
    x = LeakyReLU(alpha)(x)
    # (20, 20, nf*(1 + 1))

    dconv4 = Conv2DTranspose(out_ch, (2, 2), strides=(13, 5))(x)
    # (260, 100, out_ch)
    
    dconv4 = Cropping2D((5, 0))(dconv4)
    # (250, 100, out_ch)

    out = Activation('tanh')(dconv4)
    unet = Model(i, out, name=model_name)
    
    return unet

In [253]:
# Discriminator
def discriminator(nf, a_ch=1, b_ch=1, opt=Adam(lr=2e-4, beta_1=0.5), alpha=0.2, model_name='d'):
    ''' параметры:
    a_ch - число каналов первого изображения
    b_ch - число каналов второго
    nf - число фильтров на входном слое
    alpha - параметр LeakyReLU
    '''
    i = Input(shape=(500, 100, a_ch + b_ch))
    # (500, 100, a_ch + b_ch)
    
    conv1 = Conv2D(nf, (6, 6), padding='same', strides=(5,5))(i)
    x = LeakyReLU(alpha)(conv1)
    # (100, 20, nf)
    
    conv2 = Conv2D(nf*5, (6, 6), padding='same', strides=(5,5))(x)
    x = LeakyReLU(alpha)(conv2)
    # (20, 4, nf*5)
    
    conv3 = Conv2D(1, (3, 3), padding='same', strides=(2,2))(x)
    out = Activation('sigmoid')(conv3)
    # (10, 2, 1)
    
    d = Model(i, out, name=model_name)
    
    def d_loss(y_true, y_pred):
        L = objectives.binary_crossentropy(K.batch_flatten(y_true), K.batch_flatten(y_pred))
        return L
    
    d.compile(optimizer=opt, loss=d_loss)
    return d

In [254]:
def full_generator(nf, in_ch=1, out_ch=1, batch_size=1, alpha=0.2, model_name='f_gen'):
    a1 = Input(shape=(100, 100, in_ch))
    a2 = Input(shape=(100, 100, in_ch))
    gen1 = g_unet(nf, 'unet1', in_ch, out_ch, batch_size, alpha)
    out1 = gen1(a1)
    gen2 = g_unet(nf, 'unet2', in_ch, out_ch, batch_size, alpha)
    out2 = gen2(a2)
    out = concatenate([out1, out2], axis=1)
    f_gen = Model([a1, a2], out, name=model_name)
    return f_gen

In [257]:
def pix2pix(atob, d, a_ch=1, b_ch=1, alpha=100, opt=Adam(lr=2e-4, beta_1=0.5), model_name='pix2pix'):
    '''
    atob - full generator
    d - discriminator
    '''
    a1 = Input(shape=(100, 100, a_ch))
    a2 = Input(shape=(100, 100, a_ch))
    b = Input(shape=(500, 100, b_ch))
    
    # генерируем картинку на основе a1 и a2 с помощью объединенного генератора:
    bp = atob([a1, a2])
    
    # дискриминатор получает на вход пару изображений
    d_in = concatenate([b, bp], axis=3)
    pix2pix = Model([a1, a2, b], d(d_in), name=model_name)
    
    def p2p_loss(y_true, y_pred):
        y_true_flat = K.batch_flatten(y_true)
        y_pred_flat = K.batch_flatten(y_pred)
        
        # adversarial loss
        L_adv = objectives.binary_crossentropy(y_true_flat, y_pred_flat)
        
        # atob loss
        b_flat = K.batch_flatten(b)
        bp_flat = K.batch_flatten(bp)
        L_atob = K.mean(K.abs(b_flat - bp_flat))
        
        return L_adv + alpha*L_atob
    
    # обучаем генератор - фризим дискриминатор
    pix2pix.get_layer('d').trainable = False
    
    pix2pix.compile(optimizer=opt, loss=p2p_loss)
    return pix2pix

### Dataset loading + preprocessing

In [None]:
from PIL import Image

In [None]:
H = 100
W = 500
dataPrefix = '../data/sand/full_dataset/trend1/panorama/sample'
dataExt = '.jpg'

def loadImage(i):
    fileName = dataPrefix + str(i) + dataExt
    im = Image.open(fileName)
    return np.array(im)

In [None]:
im = loadImage(1)

In [None]:
im.shape

In [None]:
im = im.reshape(1, 100, 100, 1)
y = np.array([0])

### Training

In [258]:
f_gen = full_generator(5)
d = discriminator(5)
p2p = pix2pix(f_gen, d)

In [259]:
f_gen.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_110 (InputLayer)           (None, 100, 100, 1)   0                                            
____________________________________________________________________________________________________
input_111 (InputLayer)           (None, 100, 100, 1)   0                                            
____________________________________________________________________________________________________
unet1 (Model)                    (None, 250, 100, 1)   48021                                        
____________________________________________________________________________________________________
unet2 (Model)                    (None, 250, 100, 1)   48021                                        
___________________________________________________________________________________________

In [260]:
d.output_shape

(None, 10, 2, 1)

In [261]:
p2p.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_115 (InputLayer)           (None, 100, 100, 1)   0                                            
____________________________________________________________________________________________________
input_116 (InputLayer)           (None, 100, 100, 1)   0                                            
____________________________________________________________________________________________________
input_117 (InputLayer)           (None, 500, 100, 1)   0                                            
____________________________________________________________________________________________________
f_gen (Model)                    (None, 500, 100, 1)   96042                                        
___________________________________________________________________________________________

### Using trained NN

In [262]:
p2p.layers

[<keras.engine.topology.InputLayer at 0x7f8372f5c860>,
 <keras.engine.topology.InputLayer at 0x7f83729e39e8>,
 <keras.engine.topology.InputLayer at 0x7f83729e39b0>,
 <keras.engine.training.Model at 0x7f8373075b70>,
 <keras.layers.merge.Concatenate at 0x7f8371ec9f98>,
 <keras.engine.training.Model at 0x7f8372a61da0>]

In [263]:
f_gen.output_shape

(None, 500, 100, 1)

In [264]:
d.input_shape

(None, 500, 100, 2)

In [265]:
d.output_shape

(None, 10, 2, 1)