In [1]:
import os

import cv2
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image

  from ._conv import register_converters as _register_converters


In [2]:
def load_img(path, size=(100, 100)):
    """Read an image from disk, scale pixel values by 1/255 and resize it.

    Args:
        path: file path readable by OpenCV.
        size: target size passed to cv2.resize, i.e. (width, height).
            Defaults to the (100, 100) the network placeholders expect.

    Returns:
        Float array of shape (height, width, channels). OpenCV loads colour
        images in BGR channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read `path` (missing file or
            unsupported format).
    """
    img = cv2.imread(path)
    # cv2.imread signals failure by returning None rather than raising; without
    # this check the caller sees an opaque TypeError from `None / 255`.
    if img is None:
        raise FileNotFoundError("could not read image: %s" % path)
    return cv2.resize(img / 255, size)

In [3]:
def _conv_block(net, name, filters, kernel_size, pool_size, reuse, is_training, verbose):
    """One feature stage under variable scope `name`: VALID conv + ReLU -> max-pool -> batch-norm."""
    with tf.variable_scope(name) as scope:
        net = tf.contrib.layers.conv2d(net, filters, kernel_size,
                                       stride=1,
                                       activation_fn=tf.nn.relu,
                                       padding='VALID',
                                       weights_initializer=tf.keras.initializers.he_normal(),
                                       scope=scope,
                                       reuse=reuse)
        if verbose:
            print(net.shape)

        # Non-overlapping pooling: window and stride are both pool_size.
        net = tf.contrib.layers.max_pool2d(net, kernel_size=pool_size, stride=pool_size)

        # batch_norm shares the conv's scope exactly as the original per-layer code did,
        # so variable names are unchanged. is_training=True normalises with batch
        # statistics; False uses the moving averages instead.
        net = tf.contrib.layers.batch_norm(net, is_training=is_training, reuse=reuse, scope=scope)
        if verbose:
            print(net.shape)
    return net


def mynet(inp, reuse=False, is_training=True, verbose=True):
    """Embedding network shared by both branches of the siamese model.

    The network has three conv stages (32 filters 15x15, 64 filters 8x8,
    256 filters 5x5; max-pool 2, 3, 2), then a flatten, then a 64-unit ReLU
    fully-connected layer. For a (?, 100, 100, 3) input the output is (?, 64);
    see the shapes printed when the graph is built.

    Args:
        inp: NHWC image tensor.
        reuse: False to create the variables under "model/...", True to share
            them with a previous call.
        is_training: forwarded to every batch_norm. True (the previous behaviour)
            uses batch statistics; False uses the moving averages. May be a bool
            tensor, e.g. tf.placeholder_with_default(True, []).
        verbose: print intermediate tensor shapes while building the graph.

    Returns:
        The output tensor of the last fully-connected layer (after ReLU).
    """
    with tf.variable_scope("model"):
        if verbose:
            print(inp.shape)

        net = _conv_block(inp, "conv1", 32, 15, 2, reuse, is_training, verbose)
        net = _conv_block(net, "conv2", 64, 8, 3, reuse, is_training, verbose)
        net = _conv_block(net, "conv3", 256, 5, 2, reuse, is_training, verbose)

        # Despite the name, this scope only flattens; it creates no variables.
        with tf.variable_scope("fc1"):
            net = tf.contrib.layers.flatten(net)
            if verbose:
                print(net.shape)

        with tf.variable_scope("fc2") as scope:
            net = tf.contrib.layers.fully_connected(net, 64,
                                                    activation_fn=tf.nn.relu,
                                                    reuse=reuse,
                                                    scope=scope)
            if verbose:
                print(net.shape)

    return net


def contrastive_loss(model1, model2, y, margin, eps=1e-8):
    """Contrastive loss between two batches of embeddings.

    loss = mean(y * d^2 + (1 - y) * max(margin - d, 0)^2) / 2, where d is the
    Euclidean distance between corresponding rows of model1 and model2.

    Args:
        model1, model2: embedding tensors of shape (batch, dim).
        y: float tensor of shape (batch, 1); 1 for matching pairs, 0 otherwise.
        margin: distance beyond which dissimilar pairs stop contributing.
        eps: small constant added under the square root so its gradient stays
            finite when two embeddings coincide.

    Returns:
        Scalar loss tensor.
    """
    with tf.name_scope("contrastive-loss"):
        sq_dist = tf.reduce_sum(tf.square(model1 - model2), 1, keepdims=True)
        # Use the squared distance directly for positives. square(sqrt(s)) has a
        # 0 * inf = NaN gradient at s == 0, i.e. whenever embeddings are identical.
        positive = y * sq_dist
        dist = tf.sqrt(sq_dist + eps)
        negative = (1 - y) * tf.square(tf.maximum(margin - dist, 0.0))
        return tf.reduce_mean(positive + negative) / 2

In [4]:
# Inputs: a batch of photos (left) and a batch of sketches (right), both 100x100x3.
left = tf.placeholder(tf.float32, [None, 100, 100, 3], name='left')
right = tf.placeholder(tf.float32, [None, 100, 100, 3], name='right')

# 1 if same, 0 if different. Declared as float32 directly so `label` stays the
# placeholder the training loop feeds; the previous int32 placeholder was
# shadowed by its tf.to_float() result.
label = tf.placeholder(tf.float32, [None, 1], name='label')

margin = 1

# Both branches share one set of weights (the second call reuses the variables).
left_output = mynet(left, reuse=False)
right_output = mynet(right, reuse=True)

loss = contrastive_loss(left_output, right_output, label, margin)

# batch_norm registers its moving-average updates in UPDATE_OPS; they only run
# if the train op depends on them.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optim = tf.train.AdamOptimizer(0.0005).minimize(loss)

init_op = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init_op)

(?, 100, 100, 3)
(?, 86, 86, 32)
(?, 43, 43, 32)
(?, 36, 36, 64)
(?, 12, 12, 64)
(?, 8, 8, 256)
(?, 4, 4, 256)
(?, 4096)
(?, 64)
(?, 100, 100, 3)
(?, 86, 86, 32)
(?, 43, 43, 32)
(?, 36, 36, 64)
(?, 12, 12, 64)
(?, 8, 8, 256)
(?, 4, 4, 256)
(?, 4096)
(?, 64)


In [7]:
# Root of the Sketchy dataset; override with the SKETCHY_DIR environment variable.
dataset_path = os.environ.get("SKETCHY_DIR", "/content/Sketchy/")
photo_path = os.path.join(dataset_path, 'photo/')
sketch_path = os.path.join(dataset_path, 'sketch/')


def _list_categories(root):
    """Map each category sub-directory of `root` to a sorted list of its file names."""
    return {
        category: sorted(os.listdir(os.path.join(root, category)))
        for category in sorted(os.listdir(root))
        # Skip stray files such as .DS_Store; os.listdir on a file would raise.
        if os.path.isdir(os.path.join(root, category))
    }


def get_dict(photo_dir=None, sketch_dir=None):
    """List the dataset as {category: [file names]} for photos and for sketches.

    Listings are sorted so that seeded sampling downstream is reproducible.

    Args:
        photo_dir: photo root; defaults to the module-level `photo_path`.
        sketch_dir: sketch root; defaults to the module-level `sketch_path`.

    Returns:
        (photo_dictionary, sketch_dictionary). Both are fresh objects on every
        call, so callers may mutate them.
    """
    photo_dictionary = _list_categories(photo_path if photo_dir is None else photo_dir)
    sketch_dictionary = _list_categories(sketch_path if sketch_dir is None else sketch_dir)
    return photo_dictionary, sketch_dictionary

In [8]:
def _pop_random(dictionary, category):
    """Remove a uniformly random file name from dictionary[category] and return it."""
    name = np.random.choice(dictionary[category])
    dictionary[category].remove(name)
    return name


def get_batch(photo_dictionary, sketch_dictionary, batch_size=128, p_same=0.5):
    """Sample a batch of (photo, sketch, label) pairs for the siamese network.

    With probability `p_same` a pair comes from the same category (label 1).
    Otherwise the photo and sketch come from two different categories (label 0).
    Sampled file names are removed from the dictionaries, so no file repeats
    within a batch; pass fresh copies for every batch.

    Args:
        photo_dictionary: {category: [photo file names]}; mutated in place.
        sketch_dictionary: {category: [sketch file names]}; mutated in place.
        batch_size: number of pairs to sample.
        p_same: probability of a same-category pair.

    Returns:
        (images, sketches, labels): arrays of load_img() outputs for the photos
        and sketches, and an int array of shape (batch_size,) with the labels.

    Raises:
        ValueError: if no category has files left for a required draw.
    """
    labels, photo_files, sketch_files = [], [], []

    for _ in range(batch_size):
        same = np.random.uniform() >= 1 - p_same

        # Only draw from categories that still have files left, so an exhausted
        # (or sketch-less) category cannot crash np.random.choice / the lookup.
        photo_classes = [c for c, files in photo_dictionary.items() if files]
        if same:
            photo_classes = [c for c in photo_classes if sketch_dictionary.get(c)]
        if not photo_classes:
            raise ValueError("no photo category left to sample from")
        photo_class = np.random.choice(photo_classes)
        photo = _pop_random(photo_dictionary, photo_class)

        if same:
            sketch_class = photo_class
        else:
            other_classes = [c for c, files in sketch_dictionary.items()
                             if files and c != photo_class]
            if not other_classes:
                raise ValueError("no other sketch category left to sample from")
            sketch_class = np.random.choice(other_classes)
        sketch = _pop_random(sketch_dictionary, sketch_class)

        photo_files.append(os.path.join(dataset_path, 'photo', photo_class, photo))
        sketch_files.append(os.path.join(dataset_path, 'sketch', sketch_class, sketch))
        labels.append(1 if same else 0)

    images = np.array([load_img(f) for f in photo_files])
    sketches = np.array([load_img(f) for f in sketch_files])
    return images, sketches, np.array(labels)

In [14]:
# Retrieval gallery: N_TEST_PER_CLASS random photos from every category, shuffled.
# NOTE(review): these come from the same photo folder get_batch samples for
# training, so the gallery overlaps the training data.
N_TEST_PER_CLASS = 20

test_image_paths = []
for category in os.listdir(photo_path):
    category_path = os.path.join(photo_path, category + '/')
    if not os.path.isdir(category_path):
        continue  # skip stray files such as .DS_Store
    chosen = np.random.choice(os.listdir(category_path), size=N_TEST_PER_CLASS, replace=False)
    test_image_paths.extend(category_path + name for name in chosen)
np.random.shuffle(test_image_paths)
len(np.unique(test_image_paths))

2500

In [15]:
# Query for the retrieval demo: a random sketch from a random category.
query_category_dir = os.path.join(sketch_path, np.random.choice(os.listdir(sketch_path)))
test_sketch_path = os.path.join(query_category_dir, np.random.choice(os.listdir(query_category_dir)))
test_sketch_path

'/content/Sketchy/sketch/lion/n02129165_12039-2.png'

In [16]:
# Load the query sketch and prepend a batch axis so it can be fed to the network.
test_sketch = load_img(test_sketch_path)[np.newaxis, ...]
test_sketch.shape

(1, 100, 100, 3)

In [None]:
# Load every gallery photo, in test_image_paths order, into one stacked array.
test_images = np.array(list(map(load_img, test_image_paths)))
test_images.shape

In [None]:
# Show the query sketch as stored on disk (Image is imported in the top import cell).
Image.open(test_sketch_path)

In [None]:
N_STEPS = 1000000      # optimisation steps, one batch each
EVAL_EVERY = 100       # run the retrieval demo every EVAL_EVERY steps
EVAL_BATCH_SIZE = 8    # images per forward pass when embedding the gallery
TOP_K = 5              # number of retrieved photos to display


def embed(images, batch_size=EVAL_BATCH_SIZE):
    """Run `images` through the shared network in chunks; one output row per image.

    The final partial chunk is included, so no image is silently dropped.
    """
    chunks = [sess.run(left_output, {left: images[start:start + batch_size]})
              for start in range(0, len(images), batch_size)]
    return np.vstack(chunks)


# Listing the dataset is slow, so do it once. get_batch removes the files it
# samples from the lists, so it gets a fresh copy every step.
base_photo_dictionary, base_sketch_dictionary = get_dict()

for step in range(N_STEPS):
    photo_dictionary = {k: list(v) for k, v in base_photo_dictionary.items()}
    sketch_dictionary = {k: list(v) for k, v in base_sketch_dictionary.items()}
    p_batch, s_batch, lab_batch = get_batch(photo_dictionary, sketch_dictionary)
    lab_batch = lab_batch.reshape(-1, 1)
    _, loss_ = sess.run([optim, loss], {left: p_batch, right: s_batch, label: lab_batch})

    if step % EVAL_EVERY == 0:
        # NOTE(review): mynet's batch_norm runs with is_training=True, so these
        # embeddings use per-batch statistics (the query is a batch of one).
        # Confirm this is intended before trusting the retrievals.
        sketch_repr = embed(test_sketch)             # one row for the query
        image_representations = embed(test_images)   # one row per gallery photo

        # Broadcasting compares the single query row against every gallery row.
        diff = np.sqrt(np.mean((sketch_repr - image_representations) ** 2, -1))
        top_k = np.argsort(diff)[:TOP_K]

        print('##' + str(step) + ' : loss == ' + str(loss_))

        fig, axes = plt.subplots(1, TOP_K, figsize=(20, 20))
        for ax, idx in zip(np.atleast_1d(axes), top_k):
            ax.imshow(mpimg.imread(test_image_paths[idx]))
            ax.set_xticks([])
            ax.set_yticks([])
        plt.show()
        # Release the figure; this loop can create thousands of them.
        plt.close(fig)