In [1]:
%matplotlib inline

In [2]:
from __future__ import print_function, division
import scipy
import tensorflow as tf
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D, Conv2DTranspose, Deconv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import os

Using TensorFlow backend.


In [3]:
img_rows = 32
img_cols = 32
channels = 3
img_shape = (img_rows, img_cols, channels)
gen_inshape = (img_rows, img_cols, 1)
gen_outshape = (img_rows, img_cols, 3)
patch = int(img_rows / 2**4)
disc_patch = (patch, patch, 1)


dataset_name = 'cifar-10-batches-img'
data_loader = DataLoader(dataset_name=dataset_name,
                         img_res=(img_rows, img_cols))

In [4]:
imgs_A= data_loader.load_data(batch_size=10, is_testing=False)

In [6]:
def build_generator():
    inputs = Input(shape=gen_inshape)

    leaky_relu = tf.nn.leaky_relu
    e1 = Conv2D(64,kernel_size = (4,4),strides = 1, activation = leaky_relu,padding='same')(inputs)
    e1 = BatchNormalization()(e1)
    e2 = Conv2D(128,kernel_size = (4,4),strides = 2, activation = leaky_relu,padding='same')(e1)
    e2 = BatchNormalization()(e2)
    e3 = Conv2D(256,kernel_size = (4,4),strides = 2, activation = leaky_relu,padding='same')(e2)
    e3 = BatchNormalization()(e3)
    e4 = Conv2D(512,kernel_size = (4,4),strides = 2, activation = leaky_relu,padding='same')(e3)
    e4 = BatchNormalization()(e4)
    e5 = Conv2D(512,kernel_size = (4,4),strides = 2, activation = leaky_relu,padding='same')(e4)
    e5 = BatchNormalization()(e5)
    
    d1 = Deconv2D(512,kernel_size = (4,4),strides = 2, activation = 'relu',padding='same')(e5)
    d1 = BatchNormalization()(d1)
    d1 = Dropout(0.5)(d1)
    d1 = concatenate([d1,e4],axis = 3)
    d2 = Deconv2D(256,kernel_size = (4,4),strides = 2, activation = 'relu',padding='same')(d1)
    d2 = BatchNormalization()(d2)
    d2 = Dropout(0.5)(d2)
    d2 = concatenate([d2,e3],axis = 3)
    d3 = Deconv2D(128,kernel_size = (4,4),strides = 2, activation = 'relu',padding='same')(d2)
    d3 = BatchNormalization()(d3)
    d3 = concatenate([d3,e2],axis = 3)
    d4 = Deconv2D(64,kernel_size = (4,4),strides = 2, activation = 'relu',padding='same')(d3)
    d4 = BatchNormalization()(d4)
    d4 = concatenate([d4,e1],axis = 3)
    
    d5 = Conv2D(3,kernel_size = (1,1),strides = 1,activation='tanh')(d4)


    return Model(inputs, d5)

In [7]:
def build_discriminator():
    img_A = Input(shape=img_shape)
    img_B = Input(shape=gen_inshape)
    leaky_relu = tf.nn.leaky_relu
    combined_imgs = Concatenate(axis=3)([img_A, img_B])

    d1 = Conv2D(64, kernel_size=(4,4), strides = 2, activation = leaky_relu, padding="same")(combined_imgs)
    d2 = Conv2D(128, kernel_size=(4,4), strides = 2, activation = leaky_relu, padding="same")(d1)
    d2 = BatchNormalization()(d2)
    d3 = Conv2D(256, kernel_size=(4,4), strides = 2, activation = leaky_relu, padding="same")(d2)
    d3 = BatchNormalization()(d3)
    d4 = Conv2D(512, kernel_size=(4,4), strides = 1, activation = leaky_relu, padding="same")(d3)
    d4 = BatchNormalization()(d4)
    d5 = Conv2D(1, kernel_size=(4,4), strides = 1, activation = 'sigmoid', padding="same")(d4)

    return Model([img_A, img_B], d5)

In [8]:
generator = build_generator()
discriminator = build_discriminator()

In [9]:
from skimage import io, color
def keras_preprocess(img):
    lab = color.rgb2lab(img/255.0)
    L,A,B=cv2.split(lab)
    L = L / 50 - 1
    A = A / 110
    B = B / 110
    ret = cv2.merge((L,A,B))
    return ret

In [10]:
from skimage import io, color
def keras_postprocess(img):
    L,A,B=cv2.split(img)
    L = (L + 1) / 2 * 100
    A = A * 110
    B = B * 110
    lab = cv2.merge((L,A,B))
    ret = color.lab2rgb(lab)
    ret = ret*255
    return ret

In [11]:
from keras import losses
gen_optimizer = Adam(0.005, 0.5)
dis_optimizer = Adam(0.0005, 0.5)

discriminator.compile(loss='binary_crossentropy',
            optimizer=dis_optimizer)

img_A = Input(shape=img_shape)
img_B = Input(shape=gen_inshape)

fake_A = generator(img_B)

discriminator.trainable = False

valid = discriminator([fake_A, img_B])



combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])
combined.compile(loss=['binary_crossentropy', 'mae'],
                      loss_weights=[1, 100],
                      optimizer=gen_optimizer)

In [14]:
import cv2
def sample_images(epoch, batch_i):
    r, c = 3, 3

    img = data_loader.load_data(batch_size=3, is_testing=True)
    
    img = img.astype('float64')
    imgs_B = []
    for i in img:
        imgs_B.append(color.rgb2grey(i).reshape(img_rows,img_cols,1))
    imgs_B = np.array(imgs_B)
    
    fake_A = generator.predict(imgs_B)

    titles = ['Gray', 'Generated', 'Original']
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        axs[i,0].imshow(imgs_B[i].reshape(img_rows,img_cols),cmap='gray')
        axs[i,0].set_title(titles[0])
        axs[i,0].axis('off')
        
        axs[i,1].imshow(keras_postprocess(fake_A[i])/255.0)
        axs[i,1].set_title(titles[1])
        axs[i,1].axis('off')
        
        axs[i,2].imshow(img[i].astype('int'))
        axs[i,2].set_title(titles[2])
        axs[i,2].axis('off')
    fig.savefig("images_samples/%d_%d.png" % (epoch, batch_i))
    plt.close()

In [17]:
import cv2
from tqdm import tqdm
def train(epochs, batch_size=1, sample_interval=50):

    start_time = datetime.datetime.now()
    fake = np.zeros((batch_size,) + (4,4,1))
    valid = np.ones((batch_size,) + (4,4,1))
    for epoch in range(epochs):
        print (epoch)
        for batch_i, img in enumerate(tqdm(data_loader.load_batch(batch_size))):
            img = img.astype('float64')
            imgs_A = []
            imgs_B = []
            for i in img:
                imgs_A.append(keras_preprocess(i))
                imgs_B.append(color.rgb2grey(i).reshape(img_rows,img_cols,1))
            imgs_A = np.array(imgs_A)
            imgs_B = np.array(imgs_B)
            fake_A = generator.predict(imgs_B)
            d_loss_real = discriminator.train_on_batch([imgs_A, imgs_B], valid)
            d_loss_fake = discriminator.train_on_batch([fake_A, imgs_B], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            g_loss = combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])

            elapsed_time = datetime.datetime.now() - start_time

            #print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] time: %s" % (epoch, epochs,
            #                                                        batch_i, data_loader.n_batches,
            #                                                        d_loss,
            #                                                        g_loss[0],
            #                                                        elapsed_time))

            if batch_i % sample_interval == 0:
                sample_images(epoch, batch_i)



In [18]:
train(epochs=200, batch_size=256, sample_interval=200)

0it [00:00, ?it/s]

0


  'Discrepancy between trainable weights and collected trainable'
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
194it [01:30,  2.30it/s]
0it [00:00, ?it/s]

1


194it [01:23,  2.30it/s]
0it [00:00, ?it/s]

2


194it [01:23,  2.36it/s]
0it [00:00, ?it/s]

3


194it [01:23,  2.36it/s]
0it [00:00, ?it/s]

4


194it [01:23,  2.35it/s]
0it [00:00, ?it/s]

5


  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
194it [01:23,  2.35it/s]
0it [00:00, ?it/s]

6


194it [01:23,  2.31it/s]
0it [00:00, ?it/s]

7


194it [01:23,  2.36it/s]
0it [00:00, ?it/s]

8


194it [01:23,  2.30it/s]
0it [00:00, ?it/s]

9


194it [01:23,  2.42it/s]
0it [00:00, ?it/s]

10


194it [01:22,  2.37it/s]
0it [00:00, ?it/s]

11


194it [01:22,  2.33it/s]
0it [00:00, ?it/s]

12


194it [01:23,  2.37it/s]
0it [00:00, ?it/s]

13


194it [01:23,  2.41it/s]
0it [00:00, ?it/s]

14


194it [01:23,  2.38it/s]
0it [00:00, ?it/s]

15


194it [01:22,  2.37it/s]
0it [00:00, ?it/s]

16


194it [01:22,  2.35it/s]
0it [00:00, ?it/s]

17


194it [01:23,  2.34it/s]
0it [00:00, ?it/s]

18


  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
194it [01:24,  2.29it/s]
0it [00:00, ?it/s]

19


194it [01:24,  2.29it/s]
0it [00:00, ?it/s]

20


194it [01:24,  2.36it/s]
0it [00:00, ?it/s]

21


194it [01:23,  2.34it/s]
0it [00:00, ?it/s]

22


194it [01:24,  2.24it/s]
0it [00:00, ?it/s]

23


194it [01:23,  2.35it/s]
0it [00:00, ?it/s]

24


194it [01:23,  2.36it/s]
0it [00:00, ?it/s]

25


194it [01:22,  2.38it/s]
0it [00:00, ?it/s]

26


194it [01:22,  2.31it/s]
0it [00:00, ?it/s]

27


194it [01:23,  2.31it/s]
0it [00:00, ?it/s]

28


194it [01:22,  2.36it/s]
0it [00:00, ?it/s]

29


194it [01:22,  2.37it/s]
0it [00:00, ?it/s]

30


194it [01:22,  2.36it/s]
0it [00:00, ?it/s]

31


194it [01:23,  2.33it/s]
0it [00:00, ?it/s]

32


194it [01:22,  2.31it/s]
0it [00:00, ?it/s]

33


194it [01:23,  2.37it/s]
0it [00:00, ?it/s]

34


194it [01:23,  2.38it/s]
0it [00:00, ?it/s]

35


  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
194it [01:23,  2.27it/s]
0it [00:00, ?it/s]

36


194it [01:24,  2.27it/s]
0it [00:00, ?it/s]

37


194it [01:24,  2.31it/s]
0it [00:00, ?it/s]

38


194it [01:25,  2.34it/s]
0it [00:00, ?it/s]

39


194it [01:23,  2.26it/s]
0it [00:00, ?it/s]

40


194it [01:24,  2.30it/s]
0it [00:00, ?it/s]

41


194it [01:23,  2.38it/s]
0it [00:00, ?it/s]

42


194it [01:22,  2.35it/s]
0it [00:00, ?it/s]

43


194it [01:23,  2.31it/s]
0it [00:00, ?it/s]

44


194it [01:23,  2.33it/s]
0it [00:00, ?it/s]

45


194it [01:23,  2.36it/s]
0it [00:00, ?it/s]

46


194it [01:23,  2.38it/s]
0it [00:00, ?it/s]

47


194it [01:23,  2.34it/s]
0it [00:00, ?it/s]

48


194it [01:23,  2.38it/s]
0it [00:00, ?it/s]

49


194it [01:23,  2.35it/s]
0it [00:00, ?it/s]

50


194it [01:23,  2.35it/s]
0it [00:00, ?it/s]

51


194it [01:22,  2.37it/s]
0it [00:00, ?it/s]

52


194it [01:22,  2.29it/s]
0it [00:00, ?it/s]

53


194it [01:22,  2.31it/s]
0it [00:00, ?it/s]

54


194it [01:23,  2.36it/s]
0it [00:00, ?it/s]

55


194it [01:22,  2.31it/s]
0it [00:00, ?it/s]

56


194it [01:22,  2.36it/s]
0it [00:00, ?it/s]

57


194it [01:22,  2.33it/s]
0it [00:00, ?it/s]

58


194it [01:22,  2.32it/s]
0it [00:00, ?it/s]

59


194it [01:22,  2.32it/s]
0it [00:00, ?it/s]

60


194it [01:22,  2.40it/s]
0it [00:00, ?it/s]

61


194it [01:23,  2.34it/s]
0it [00:00, ?it/s]

62


194it [01:23,  2.26it/s]
0it [00:00, ?it/s]

63


194it [01:22,  2.37it/s]
0it [00:00, ?it/s]

64


194it [01:22,  2.36it/s]
0it [00:00, ?it/s]

65


194it [01:23,  2.34it/s]
0it [00:00, ?it/s]

66


194it [01:22,  2.35it/s]
0it [00:00, ?it/s]

67


194it [01:23,  2.33it/s]
0it [00:00, ?it/s]

68


194it [01:24,  2.32it/s]
0it [00:00, ?it/s]

69


194it [01:23,  2.35it/s]
0it [00:00, ?it/s]

70


194it [01:22,  2.35it/s]
0it [00:00, ?it/s]

71


194it [01:22,  2.36it/s]
0it [00:00, ?it/s]

72


194it [01:22,  2.37it/s]
0it [00:00, ?it/s]

73


194it [01:22,  2.40it/s]
0it [00:00, ?it/s]

74


194it [01:22,  2.33it/s]
0it [00:00, ?it/s]

75


194it [01:23,  2.33it/s]
0it [00:00, ?it/s]

76


194it [01:22,  2.37it/s]
0it [00:00, ?it/s]

77


194it [01:22,  2.41it/s]
0it [00:00, ?it/s]

78


194it [01:22,  2.40it/s]
0it [00:00, ?it/s]

79


194it [01:24,  2.32it/s]
0it [00:00, ?it/s]

80


194it [01:23,  2.34it/s]
0it [00:00, ?it/s]

81


194it [01:22,  2.33it/s]
0it [00:00, ?it/s]

82


194it [01:22,  2.33it/s]
0it [00:00, ?it/s]

83


194it [01:23,  2.35it/s]
0it [00:00, ?it/s]

84


194it [01:23,  2.36it/s]
0it [00:00, ?it/s]

85


194it [01:23,  2.34it/s]
0it [00:00, ?it/s]

86


194it [01:23,  2.35it/s]
0it [00:00, ?it/s]

87


194it [01:23,  2.30it/s]
0it [00:00, ?it/s]

88


194it [01:23,  2.34it/s]
0it [00:00, ?it/s]

89


194it [01:22,  2.38it/s]
0it [00:00, ?it/s]

90


194it [01:23,  2.40it/s]
0it [00:00, ?it/s]

91


194it [01:22,  2.34it/s]
0it [00:00, ?it/s]

92


194it [01:22,  2.33it/s]
0it [00:00, ?it/s]

93


  warn('Color data out of range: Z < 0 in %s pixels' % invalid[0].size)
194it [01:22,  2.35it/s]
0it [00:00, ?it/s]

94


194it [01:23,  2.35it/s]
0it [00:00, ?it/s]

95


194it [01:23,  2.37it/s]
0it [00:00, ?it/s]

96


194it [01:23,  2.28it/s]
0it [00:00, ?it/s]

97


194it [01:23,  2.35it/s]
0it [00:00, ?it/s]

98


194it [01:22,  2.35it/s]
0it [00:00, ?it/s]

99


194it [01:22,  2.37it/s]
0it [00:00, ?it/s]

100


194it [01:22,  2.32it/s]
0it [00:00, ?it/s]

101


194it [01:22,  2.36it/s]
0it [00:00, ?it/s]

102


194it [01:23,  2.28it/s]
0it [00:00, ?it/s]

103


194it [01:22,  2.34it/s]
0it [00:00, ?it/s]

104


194it [01:22,  2.34it/s]
0it [00:00, ?it/s]

105


194it [01:23,  2.31it/s]
0it [00:00, ?it/s]

106


194it [01:23,  2.36it/s]
0it [00:00, ?it/s]

107


194it [01:22,  2.35it/s]
0it [00:00, ?it/s]

108


194it [01:23,  2.34it/s]
0it [00:00, ?it/s]

109


194it [01:23,  2.36it/s]
0it [00:00, ?it/s]

110


194it [01:23,  2.41it/s]
0it [00:00, ?it/s]

111


194it [01:22,  2.33it/s]
0it [00:00, ?it/s]

112


194it [01:23,  2.38it/s]
0it [00:00, ?it/s]

113


194it [01:23,  2.38it/s]
0it [00:00, ?it/s]

114


194it [01:23,  2.37it/s]
0it [00:00, ?it/s]

115


194it [01:22,  2.33it/s]
0it [00:00, ?it/s]

116


194it [01:24,  2.28it/s]
0it [00:00, ?it/s]

117


194it [01:26,  2.27it/s]
0it [00:00, ?it/s]

118


194it [01:23,  2.37it/s]
0it [00:00, ?it/s]

119


194it [01:23,  2.38it/s]
0it [00:00, ?it/s]

120


194it [01:23,  2.36it/s]
0it [00:00, ?it/s]

121


194it [01:23,  2.39it/s]
0it [00:00, ?it/s]

122


194it [01:23,  2.37it/s]
0it [00:00, ?it/s]

123


194it [01:23,  2.27it/s]
0it [00:00, ?it/s]

124


194it [01:23,  2.29it/s]
0it [00:00, ?it/s]

125


194it [01:23,  2.36it/s]
0it [00:00, ?it/s]

126


194it [01:23,  2.25it/s]
0it [00:00, ?it/s]

127


194it [01:24,  2.32it/s]
0it [00:00, ?it/s]

128


194it [01:23,  2.37it/s]
0it [00:00, ?it/s]

129


194it [01:23,  2.34it/s]
0it [00:00, ?it/s]

130


194it [01:23,  2.29it/s]
0it [00:00, ?it/s]

131


194it [01:24,  2.30it/s]
0it [00:00, ?it/s]

132


194it [01:23,  2.29it/s]
0it [00:00, ?it/s]

133


194it [01:22,  2.35it/s]
0it [00:00, ?it/s]

134


194it [01:23,  2.37it/s]
0it [00:00, ?it/s]

135


194it [01:22,  2.37it/s]
0it [00:00, ?it/s]

136


194it [01:22,  2.29it/s]
0it [00:00, ?it/s]

137


194it [01:22,  2.37it/s]
0it [00:00, ?it/s]

138


194it [01:23,  2.36it/s]
0it [00:00, ?it/s]

139


194it [01:23,  2.33it/s]
0it [00:00, ?it/s]

140


194it [01:23,  2.29it/s]
0it [00:00, ?it/s]

141


194it [01:24,  2.29it/s]
0it [00:00, ?it/s]

142


194it [01:24,  2.31it/s]
0it [00:00, ?it/s]

143


194it [01:24,  2.38it/s]
0it [00:00, ?it/s]

144


53it [00:23,  2.36it/s]

KeyboardInterrupt: 