In [1]:
import os
import datetime
import imageio
import skimage
import scipy # 

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from glob import glob
from IPython.display import Image

from tensorflow.python.keras.callbacks import TensorBoard
import smtplib

import cv2
import numpy as np

tf.logging.set_verbosity(tf.logging.ERROR)

In [2]:
def kmeans(img):
    K = 5
    rows = img.shape[1]
    cols = img.shape[0]
    n = img.shape[0] * img.shape[1]
    data = img.reshape(-1,3)
    data = np.float32(data)

    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    center,labels,colors = cv2.kmeans(data, K, None, criteria,  10, cv2.KMEANS_PP_CENTERS)

    for i in range(0,n):
        data[i][0] = colors[labels[i], 0]
        data[i][1] = colors[labels[i], 1]
        data[i][2] = colors[labels[i], 2]

    reduced = data.reshape(cols, rows, 3)
    reduced = np.uint8(reduced)

    return reduced

def getClrs(img):
    reduced = kmeans(img)
    data = reduced.reshape(-1,3)
    diff_clrs = []
    init = 0
    for color in data:
        if init == 0:
            diff_clrs = np.asarray([color])
            init = 1
        else:
            if color not in diff_clrs:
                diff_clrs = np.concatenate((diff_clrs, np.asarray([color])), axis = 0)

    return diff_clrs

def checkClrs(clrs):
    height = 64
    width = 32
    init = 0
    for color in clrs:
        for i in range(width):
            if init == 0:
                data = np.array([color])
                init = 1
            else:
                data = np.concatenate((data,np.array([color])), axis = 0)

    data = data.reshape(1,len(data),3)
    data2 = data
    for i in range(height):
        data = np.concatenate((data,data2), axis = 0)

    return(data)

In [3]:
class DataLoader():
    def __init__(self, dataset_name, img_res=(128, 128)):
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "test"
        path = glob('%s/%s/*' % (self.dataset_name, data_type))
        
        batch_images = np.random.choice(path, size=batch_size)

        imgs_A = []
        imgs_B = []
        for img_path in batch_images:
            img = self.imread(img_path)

            h, w, _ = img.shape
            _w = int(w/2)
            img_A, img_B = img[:, :_w, :], img[:, _w:, :]

            img_A = scipy.misc.imresize(img_A, self.img_res)
            img_B = scipy.misc.imresize(img_B, self.img_res)

            # If training => do random flip
            if not is_testing and np.random.random() < 0.5:
                img_A = np.fliplr(img_A)
                img_B = np.fliplr(img_B)

            imgs_A.append(img_A)
            imgs_B.append(img_B)

        imgs_A = np.array(imgs_A)/127.5 - 1.
        imgs_B = np.array(imgs_B)/127.5 - 1.

        return imgs_A, imgs_B

    def load_batch(self, batch_size=1, is_testing=False):
        data_type = "train" if not is_testing else "val"
        path = glob('%s/%s/*' % (self.dataset_name, data_type))

        self.n_batches = int(len(path) / batch_size)

        for i in range(self.n_batches-1):
            batch = path[i*batch_size:(i+1)*batch_size]
            imgs_A, imgs_B = [], []
            for img in batch:
                img = self.imread(img)
                h, w, _ = img.shape
                half_w = int(w/2)
                img_A = img[:, :half_w, :]
                img_B = img[:, half_w:, :]

                img_A = scipy.misc.imresize(img_A, self.img_res)
                img_B = scipy.misc.imresize(img_B, self.img_res)

                if not is_testing and np.random.random() > 0.5:
                        img_A = np.fliplr(img_A)
                        img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)

            imgs_A = np.array(imgs_A)/127.5 - 1.
            imgs_B = np.array(imgs_B)/127.5 - 1.

            yield imgs_A, imgs_B


    def imread(self, path):
        return imageio.imread(path).astype(np.float)
    

In [5]:
def dataLoaderTest():
# Input shape
    img_rows = 128
    img_cols = 128
    channels = 3
    img_shape = (img_rows, img_cols, channels)

    # Configure data loader
    dataset_name = 'dataset2'
    data_loader = DataLoader(dataset_name=dataset_name,
                                    img_res=(img_rows, img_cols))
    
    os.makedirs('./images2/%s' % dataset_name, exist_ok=True)
    
    batch_size = 1
    for batch_i, (real_clrs, imgs) in enumerate(data_loader.load_batch(batch_size)):
        img = 0.5 * real_clrs[0] + 0.5
        print(img.shape)
        print(img[0][0])
        clrs = getClrs(img*256)
        print(clrs)
        check = checkClrs(clrs)
        r, c = 1,1
        titles = ['Condition']
        fig, axs = plt.subplots(r, c)
        axs.imshow(check)
        axs.set_title(check)
        axs.axis('off')
        fig.savefig("./images2/%s/%d_%d.png" % (dataset_name, 1, batch_i))
        plt.close()
        input("press something bro")
        
dataLoaderTest()
    
        
        
    

`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.


(128, 128, 3)
[0.96862745 0.82352941 0.71372549]


  if s != self._text:


[[228 175 152]
 [135 105 120]
 [ 12  12  33]
 [ 57  86 150]
 [ 30  40  95]]
press something bro
(128, 128, 3)
[0.21960784 0.10196078 0.07058824]
[[ 68  56  60]
 [234 145  39]
 [183 185 141]
 [ 53 137 172]]
press something bro
(128, 128, 3)
[0.82352941 0.8        0.82352941]
[[190 184 187]
 [105  99 101]
 [ 47  41  43]
 [  5   0   2]
 [241 240 243]]
press something bro


KeyboardInterrupt: 

In [79]:
class ClrPipe():
    def __init__(self):
        # Input shape
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        
         # Configure data loader
        self.dataset_name = 'dataset2'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))
        
        # Number of filters in the first layer of G and D
        self.gf = 64
        
         # Build the generator
        self.generator = self.build_generator()
        
        
        # Input images and their conditioning images
        img_A = tf.keras.layers.Input(shape=self.img_shape)
        img_B = tf.keras.layers.Input(shape=self.img_shape)
        
        optimizer = tf.keras.optimizers.Adam(0.0002, 0.5)
        self.generator.compile(loss=['mse'], loss_weights=[15], optimizer=optimizer)
        
    def build_generator(self):
        """U-Net Generator"""

        def conv2d(layer_input, filters, f_size=4, bn=True):
            """Layers used during downsampling"""
            d = tf.keras.layers.Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = tf.keras.layers.LeakyReLU(alpha=0.2)(d)
            if bn:
                d = tf.keras.layers.BatchNormalization(momentum=0.8)(d)
            return d
        
        # Image input
        d0 = tf.keras.layers.Input(shape=self.img_shape)

        # Downsampling
        d1 = conv2d(d0, self.gf, bn=False)
        d2 = conv2d(d1, self.gf*2)
        d3 = conv2d(d2, self.gf*4)
        d4 = conv2d(d3, self.gf*8)
        d5 = conv2d(d4, self.gf*8)
        d6 = conv2d(d5, self.gf*8)
        d7 = conv2d(d6, self.gf*8)
       
        #Channel Pool
        out1 = conv2d(d7,self.gf*4)
        print(out1.shape)
        out2 = conv2d(out1,self.gf*2)
        print(out2.shape)
        out3 = conv2d(out2,self.gf)
        print(out3.shape)
        out4 = conv2d(out3, 32)
        print(out4.shape)
        out5 = conv2d(out4, 15)
        print(out5.shape)
        
        #output_img = tf.keras.layers.Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u7)
        return tf.keras.models.Model(d0, out5)
    
    def trainTest(self, batch_size=1):
        for batch_i, (real_clrs, imgs) in enumerate(self.data_loader.load_batch(batch_size)):
            fake_clrs = self.generator.predict(imgs)
            prediction = 0.5 * fake_clrs[0][0][0] + 0.5
            print(prediction.shape)
            normalized_prediction = prediction.reshape(-1,3)
            print(normalized_prediction.shape)
            print(normalized_prediction)
            colors = checkClrs(normalized_prediction)
            print(colors.shape)
            r, c = 1, 1
            titles = ['Condition']
            fig, axs = plt.subplots(r, c)
            axs.imshow(colors)
            axs.set_title('hi')
            axs.axis('off')
            fig.savefig("./images2/%s/%d_%d.png" % ('dataset2', 1, batch_i))
            plt.close()
            input("write something bro")
    
    def train(self, epochs, batch_size=1, sample_interval=50):
        start_time = datetime.datetime.now()

        for epoch in range(epochs):
            for batch_i, (real_clrs, imgs) in enumerate(self.data_loader.load_batch(batch_size)):

                # Condition on image input and generate fake color palettes
                fake_clrs = self.generator.predict(imgs)
                
                real_clrs = getClrs(imgs)
                print(real_clrs.shape)
                
                g_loss = self.generator.train_on_batch(fake_clrs, real_clrs)

                elapsed_time = datetime.datetime.now() - start_time
                # Plot the progress
                print ("[Epoch %d/%d] [Batch %d/%d] [G loss: %f] time: %s" % (epoch, epochs,
                                                                        batch_i, self.data_loader.n_batches,
                                                                        g_loss[0]/batch_size,
                                                                        elapsed_time))

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)


    def sample_images(self, epoch, batch_i):
        os.makedirs('./images/%s' % self.dataset_name, exist_ok=True)
        r, c = 3, 3

        imgs_A, imgs_B = self.data_loader.load_data(batch_size=3, is_testing=True)
        fake_A = self.generator.predict(imgs_B)

        gen_imgs = np.concatenate([imgs_B, fake_A, imgs_A])

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Condition', 'Generated', 'Original']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[i])
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("./images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        plt.close()

In [80]:
net = ClrPipe()
net.generator.summary()

(?, 1, 1, 256)
(?, 1, 1, 128)
(?, 1, 1, 64)
(?, 1, 1, 32)
(?, 1, 1, 15)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_83 (InputLayer)        (None, 128, 128, 3)       0         
_________________________________________________________________
conv2d_317 (Conv2D)          (None, 64, 64, 64)        3136      
_________________________________________________________________
leaky_re_lu_316 (LeakyReLU)  (None, 64, 64, 64)        0         
_________________________________________________________________
conv2d_318 (Conv2D)          (None, 32, 32, 128)       131200    
_________________________________________________________________
leaky_re_lu_317 (LeakyReLU)  (None, 32, 32, 128)       0         
_________________________________________________________________
batch_normalization_v1_288 ( (None, 32, 32, 128)       512       
_________________________________________________________________
conv

In [None]:
net.trainTest(1)

`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.


(15,)
(5, 3)
[[0.50000006 0.49999997 0.4999998 ]
 [0.5000004  0.50000095 0.50000083]
 [0.5        0.5000004  0.5000015 ]
 [0.50000083 0.49999997 0.50000083]
 [0.49999994 0.4999999  0.5000011 ]]
(65, 160, 3)
write something bro
(15,)
(5, 3)
[[0.5000011  0.49999997 0.49999994]
 [0.50000024 0.5000008  0.5       ]
 [0.5        0.5000005  0.5000012 ]
 [0.50000066 0.49999994 0.5000003 ]
 [0.49999994 0.4999999  0.5000002 ]]
(65, 160, 3)
write something bro
(15,)
(5, 3)
[[0.50000054 0.50000095 0.4999999 ]
 [0.50000054 0.50000054 0.49999994]
 [0.49999994 0.5000004  0.5000013 ]
 [0.50000167 0.5        0.5000014 ]
 [0.50000006 0.4999998  0.50000083]]
(65, 160, 3)
write something bro
(15,)
(5, 3)
[[0.5000008  0.50000066 0.49999985]
 [0.5000001  0.5000005  0.50000006]
 [0.49999997 0.50000036 0.5000017 ]
 [0.50000155 0.49999997 0.5000012 ]
 [0.49999994 0.49999988 0.5000003 ]]
(65, 160, 3)
write something bro
(15,)
(5, 3)
[[0.5000013  0.5000011  0.5       ]
 [0.50000024 0.50000024 0.49999994]
 [0.5  