In [None]:
import tensorflow as tf
import os
import matplotlib.pyplot as plt
import numpy as np
import datetime
from glob import glob
import scipy
import scipy.misc
import pickle
from tqdm import tqdm_notebook as tqdm
%matplotlib inline  

In [None]:
def get_batch(images, batch_size, index):
    """Return the `index`-th contiguous slice of `batch_size` rows of `images`.

    The last slice is shorter (or empty) when `images` does not divide evenly.
    """
    start = batch_size * index
    return images[start:start + batch_size]

In [None]:
def weight_variable(shape):
    """Create a trainable weight tensor of `shape`, initialised from N(0, 0.01^2)."""
    initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
    return tf.Variable(initial)

def bias_variable(shape):
    """Create a trainable bias tensor of `shape`, initialised from N(0, 0.01^2)."""
    initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
    return tf.Variable(initial)

def layer(x, shape, activation):
    """Fully connected layer: activation(x @ W + b), W of `shape`, b of length shape[1]."""
    n_out = shape[1]
    # Keep W before b: auto-generated variable names (hence checkpoint keys)
    # follow creation order.
    W = weight_variable(shape)
    b = bias_variable([n_out])
    return activation(tf.matmul(x, W) + b)

In [None]:
# ---- Hyper-parameters -------------------------------------------------------
image_size = 64 * 64 * 3      # flattened 64x64 RGB image
latent_dim = 100
encoder_internal_dim = 2048
decoder_internal_dim = 2048

input_shape = [None, image_size]
x = tf.placeholder(tf.float32, input_shape)

softplus = tf.nn.softplus
tanh = tf.nn.tanh

# NOTE: tf.Variable names are auto-generated in creation order and the Saver keys
# checkpoints by name, so the order in which variables are created below must stay
# fixed for existing checkpoints in ./model/ to keep restoring.

# ---- Encoder q(z|x) ---------------------------------------------------------
h_enc1 = layer(x, [image_size, encoder_internal_dim], activation=softplus)
h_enc2 = layer(h_enc1, [encoder_internal_dim, encoder_internal_dim], activation=softplus)
h_enc3 = layer(h_enc2, [encoder_internal_dim, encoder_internal_dim], activation=softplus)

W_mu = weight_variable([encoder_internal_dim, latent_dim])
b_mu = bias_variable([latent_dim])

W_log_sigma = weight_variable([encoder_internal_dim, latent_dim])
b_log_sigma = bias_variable([latent_dim])

z_mu = tf.matmul(h_enc3, W_mu) + b_mu
# z_log_sigma is log(sigma), the log *standard deviation*. The KL term below uses
# log(sigma^2) = 2 * z_log_sigma and sigma^2 = exp(2 * z_log_sigma).
z_log_sigma = tf.matmul(h_enc3, W_log_sigma) + b_log_sigma

# ---- Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I) ----------
epsilon = tf.random_normal(tf.stack([tf.shape(x)[0], latent_dim]))
# sigma = exp(z_log_sigma). The previous exp(z_log_sigma / 2) treated z_log_sigma as
# a log-variance. Samples were then drawn with std sqrt(sigma) while the KL term
# penalised std sigma, so the two halves of the objective disagreed.
z = z_mu + tf.exp(z_log_sigma) * epsilon

# ---- Decoder p(x|z) ---------------------------------------------------------
h_dec1 = layer(z, [latent_dim, decoder_internal_dim], activation=softplus)
h_dec2 = layer(h_dec1, [decoder_internal_dim, decoder_internal_dim], activation=softplus)
h_dec3 = layer(h_dec2, [decoder_internal_dim, decoder_internal_dim], activation=softplus)
y = layer(h_dec3, [decoder_internal_dim, image_size], activation=tf.nn.sigmoid)

# Reconstruction term: per-example Bernoulli negative log-likelihood (binary
# cross-entropy) divided by image_size. The 1e-10 keeps log() finite at y = 0 or 1.
recon = -tf.reduce_sum(x * tf.log(y + 1e-10) + (1 - x) * tf.log(1 - y + 1e-10), 1) / image_size

# Closed-form KL(q(z|x) || N(0, I)) = -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
# It is divided by image_size like `recon`, so the two terms keep their relative weight.
# 0.5 (not 1/2) so the factor can never collapse to 0 under integer division.
kl_div = -0.5 * tf.reduce_sum(
    1.0 + 2.0 * z_log_sigma - tf.square(z_mu) - tf.exp(2.0 * z_log_sigma), 1) / image_size

cost = tf.reduce_mean(recon + kl_div)
cost_kl = tf.reduce_mean(kl_div)
cost_recon = tf.reduce_mean(recon)
# Built after all model variables exist, so it saves/restores every one of them.
saver = tf.train.Saver()

In [None]:
# Output directories: checkpoints (model/) and the pickled image cache (pickle/).
# exist_ok avoids the check-then-create race and makes the cell idempotent.
for output_dir in ("./model/", "./pickle/"):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
import shutil

# Start from an empty img/ directory for this run's reconstruction snapshots.
# This is a pure-Python replacement for the `% rm` / `% mkdir` shell-alias magics,
# which depend on POSIX shell aliases.
shutil.rmtree("img/", ignore_errors=True)
os.makedirs("img/")


def create_images(i, sess, test_images, num_examples=20, image_size=28*28, channels=None):
    """Save a figure of test images (top row) and their reconstructions (bottom row).

    Args:
        i: index embedded in the file name ``img/reconstruction_%08d.png``.
        sess: session holding the graph; runs the decoder output ``y`` with ``x`` fed.
        test_images: flattened images, one per row; at most the first
            ``num_examples`` rows are shown.
        num_examples: maximum number of columns.
        image_size: length of one flattened image.
        channels: 1 (grayscale) or 3 (RGB). When None it is inferred: a perfect-square
            ``image_size`` (e.g. 28*28) is grayscale, anything else is RGB. 3*s*s is
            never a perfect square, so the inference is unambiguous. The old code always
            assumed RGB, so its own default 28*28 could not be reshaped.

    Raises:
        ValueError: if ``image_size`` is not ``side * side * channels`` or there are
            no test images.
    """
    if channels is None:
        gray_side = int(round(np.sqrt(image_size)))
        channels = 1 if gray_side * gray_side == image_size else 3
    side = int(round(np.sqrt(image_size / channels)))
    if side * side * channels != image_size:
        raise ValueError("image_size=%d is not a square image with %d channel(s)"
                         % (image_size, channels))
    # imshow takes (h, w) for grayscale and (h, w, 3) for RGB.
    shape = (side, side) if channels == 1 else (side, side, channels)
    cmap = 'gray' if channels == 1 else None

    original = get_batch(test_images, num_examples, 0)
    n_cols = len(original)  # may be fewer than num_examples for a short test set
    if n_cols == 0:
        raise ValueError("test_images is empty")
    reconstruction = sess.run(y, feed_dict={x: original})

    # squeeze=False keeps axs 2-D even when there is a single column.
    fig, axs = plt.subplots(2, n_cols, figsize=(20, 2), squeeze=False)
    for col in range(n_cols):
        axs[0][col].imshow(np.reshape(original[col], shape), cmap=cmap)
        axs[1][col].imshow(np.reshape(reconstruction[col], shape), cmap=cmap)
        axs[0][col].axis('off')
        axs[1][col].axis('off')
    fig.savefig('img/reconstruction_%08d.png' % i)
    plt.close(fig)

    
def create_latent_scatter_images(i, sess, test_images, test_labels):
    """Save a scatter of latent dims 0 and 1, coloured by argmax of `test_labels`.

    `test_labels` is presumably one-hot (rows of class indicators) — confirm at call
    site. The figure is written to ``img/latent_scatter_%08d.png``.
    """
    codes = sess.run(z, feed_dict={x: test_images})
    colours = np.argmax(test_labels, 1)
    fig, ax = plt.subplots(1, 1)
    ax.clear()
    ax.scatter(codes[:, 0], codes[:, 1], c=colours, alpha=0.2)
    ax.set_xlim([-6, 6])
    ax.set_ylim([-6, 6])
    ax.axis("off")
    fig.savefig("img/latent_scatter_%08d.png" % i)
    plt.close()

In [None]:
# Most recent checkpoint prefix (e.g. "./model/model_<timestamp>.ckpt"), or None
# when nothing has been saved yet. Previously, sorted([])[-1] raised IndexError on a
# fresh run, so the `if latest_ckpt:` guards elsewhere could never take effect.
# train() stamps files with '%Y_%m_%d_%H_%M_%S' (zero-padded), so lexicographic
# order is chronological order.
meta_files = sorted(glob(os.path.join("./model/", "model*.meta")))
latest_ckpt = meta_files[-1][:-len(".meta")] if meta_files else None
latest_ckpt

In [None]:
def train(train_images, validation_images, batch_size=100, image_size=28*28, learning_rate=0.005,
          num_epochs=10000):
    """Train the VAE with Adam, logging costs and saving images/checkpoints every epoch.

    Args:
        train_images: flattened training images, one per row. Rows that do not fill
            a whole batch at the end are skipped.
        validation_images: flattened validation images. Its first 10 rows are also
            used for the reconstruction snapshots.
        batch_size: mini-batch size.
        image_size: length of one flattened image (forwarded to ``create_images``).
        learning_rate: Adam learning rate.
        num_epochs: number of passes over ``train_images`` (previously hard-coded).

    Raises:
        ValueError: if ``train_images`` has fewer rows than ``batch_size``.

    Side effects: restores ``latest_ckpt`` when it is set, and writes
    ``img/reconstruction_*.png`` and ``model/model_<timestamp>.ckpt*`` each epoch.
    The session is always closed on exit, including on KeyboardInterrupt.
    """
    num_train_batches = len(train_images) // batch_size
    num_validation_batches = len(validation_images) // batch_size
    if num_train_batches == 0:
        # Previously this crashed later with an unbound `batch`.
        raise ValueError("need at least batch_size=%d training images, got %d"
                         % (batch_size, len(train_images)))
    test_images = validation_images[:10]

    sess = tf.Session()
    try:
        # Adam's slot variables are created here, after the module-level `saver`, so
        # they are presumably not in its checkpoint; they start fresh each call.
        optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
        sess.run(tf.global_variables_initializer())
        if latest_ckpt:
            saver.restore(sess, latest_ckpt)

        print("num of train batches: ", num_train_batches)
        print("num of validation batches: ", num_validation_batches)
        # Snapshot 0 shows the model before any training in this call.
        create_images(0, sess, test_images, num_examples=10, image_size=image_size)

        for epoch in range(num_epochs):
            print("epoch no.", epoch)

            for batch_idx in tqdm(range(num_train_batches)):
                batch = get_batch(train_images, batch_size, batch_idx)
                sess.run(optimizer, feed_dict={x: batch})

            # Evaluate all three costs in one run. Each run draws a fresh epsilon, so
            # separate runs reported KL/recon values that did not add up to the total.
            train_cost, kl_value, recon_value = sess.run(
                [cost, cost_kl, cost_recon], feed_dict={x: batch})
            print("train cost per a batch: ", train_cost)
            # epoch + 1 so the pre-training snapshot (index 0) is not overwritten.
            create_images(epoch + 1, sess, test_images, num_examples=10, image_size=image_size)
            print("cost_kl: ", kl_value)
            print("cost_recon: ", recon_value)

            if num_validation_batches:
                valid_cost = sum(
                    sess.run(cost, feed_dict={x: get_batch(validation_images, batch_size, j)})
                    for j in range(num_validation_batches))
                print("validation cost per a batch: ", valid_cost / num_validation_batches)
            else:
                print("validation set is smaller than one batch; skipping validation")

            now = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
            save_path = saver.save(sess, "model/model_{}.ckpt".format(now))
            print("model saved in file: %s" % save_path)
    finally:
        sess.close()


In [None]:
def get_z(images, batch_size=50):
    """Encode `images` into latent codes z in mini-batches, using the latest checkpoint.

    Args:
        images: flattened images, one per row.
        batch_size: rows per forward pass (previously fixed at 50).

    Returns:
        Array with one row of z per input row. z is a *sample*
        (z_mu + exp(z_log_sigma) * epsilon), so repeated calls differ.
    """
    # Ceiling division. The old `1 + len // batch_size` ran an extra empty batch
    # whenever len(images) was a multiple of batch_size. At least one batch is kept,
    # so an empty input still runs a single (empty) batch, as before.
    n_batches = max(1, -(-len(images) // batch_size))
    print("num of  batches: ", n_batches)
    # The context manager closes the session; the old version leaked one per call.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if latest_ckpt:
            saver.restore(sess, latest_ckpt)
        z_list = [sess.run(z, feed_dict={x: get_batch(images, batch_size, idx)})
                  for idx in tqdm(range(n_batches))]
    return np.concatenate(z_list, axis=0)

In [None]:
def imread(path):
    """Read the image file at `path` and return it as a float64 array.

    Uses scipy.misc.imread, which exists only in old SciPy releases (removed in 1.2).
    """
    # np.float was merely an alias of the builtin float (= float64) and was removed
    # in NumPy 1.24. np.float64 yields the identical dtype.
    return scipy.misc.imread(path).astype(np.float64)

def resize_width(image, width=64.):
    """Resize `image` to `width` columns, scaling the height to keep the aspect ratio.

    The new height is truncated to an int. Relies on scipy.misc.imresize, which
    exists only in old SciPy releases (removed in 1.3).
    """
    # Pixel counts must be integers; the float default (64.) was passed through as-is.
    width = int(width)
    h, w = np.shape(image)[:2]
    return scipy.misc.imresize(image, [int((float(h) / w) * width), width])
        
def center_crop(x, height=64):
    """Keep `height` rows from the vertical centre of `x` (all columns/channels kept).

    The top offset is round((rows - height) / 2) using Python 3 round-half-to-even.
    """
    offset = int(round((np.shape(x)[0] - height) / 2.))
    return x[offset:offset + height, :, :]

def get_image(image_path, width=64, height=64):
    """Read an image, resize it to `width` columns (aspect kept), centre-crop to `height` rows."""
    resized = resize_width(imread(image_path), width=width)
    return center_crop(resized, height=height)

def load_img_data(data, size=None):
    """Load every image path in `data` into one array of flattened pixels divided by 255.

    Args:
        data: sequence of image file paths.
        size: side length of the square resize/crop. Defaults to the notebook-level
            global `dim` (defined in a later cell), which the original always used.

    Returns:
        float64 array of shape (len(data), size * size * 3).
    """
    side = dim if size is None else size
    # np.float64 replaces np.float (removed in NumPy 1.24; it was the same dtype).
    images = np.zeros((len(data), side * side * 3), dtype=np.float64)
    for i, path in enumerate(tqdm(data)):
        image = get_image(path, side, side)
        images[i] = image.flatten() / 255.
    return images

In [None]:
# All CelebA JPEG paths, sorted so row indices are stable from run to run.
data = np.sort(glob(os.path.join("./data/celebA", "*.jpg")))
print(data[:10])
print("num of data: ", len(data))

In [None]:
#data = data[:10000]

In [None]:
# Side length of a square RGB image with `image_size` values.
dim = int(np.sqrt(image_size / 3))
# Sanity check: load the first file through the same pipeline used for training.
test = get_image(data[0], dim, dim)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 3))
ax.imshow(np.reshape(test, (dim, dim, 3)), interpolation='nearest')

In [None]:
# Decode the JPEGs once and cache them as pickle shards of image_batch_size rows.
# Later runs reload the shards instead. Only unpickle files you created yourself:
# pickle.load can execute arbitrary code.
image_batch_size = 1000
n_image_batches = len(data) // image_batch_size + 1   # the last shard may be empty

if not os.path.exists("./pickle/images_0.pickle"):
    images = load_img_data(data)
    for i in tqdm(range(n_image_batches)):
        file_name = "./pickle/images_{}.pickle".format(i)
        with open(file_name, mode='wb') as f:
            pickle.dump(get_batch(images, image_batch_size, i), f, protocol=4)
else:
    print("loading from pickles")
    images_list = []
    for i in tqdm(range(n_image_batches)):
        file_name = "./pickle/images_{}.pickle".format(i)
        with open(file_name, mode='rb') as f:
            images_list.append(pickle.load(f))
    images = np.concatenate(images_list, axis=0)
    del images_list

In [None]:
%%time
image_batch_size = 1000
n_image_batches = 1 + len(data) // image_batch_size

from concurrent.futures import ProcessPoolExecutor, as_completed


def load_pickel(i):
    """Load the i-th pickled image shard from ./pickle/ (trusted local cache only)."""
    file_name = "./pickle/images_{}.pickle".format(i)
    with open(file_name, mode='rb') as f:
        return pickle.load(f)


# NOTE(review): worker processes must be able to find `load_pickel`. That works with
# the "fork" start method (Linux default). Under "spawn" (macOS/Windows defaults), a
# function defined in a notebook may not be importable by the workers — confirm.
worker_num = 5
with ProcessPoolExecutor(worker_num) as executor:
    futures = [executor.submit(load_pickel, i) for i in range(n_image_batches)]
# Collect in submission (= shard index) order. as_completed() yields futures in
# completion order, which could scramble the rows of `images` relative to `data`.
images_list = [future.result() for future in futures]



In [None]:
%%time
# Sequential reference loader: read every shard in index order.
image_batch_size = 1000
n_image_batches = len(data) // image_batch_size + 1
images_list = []
for shard_idx in tqdm(range(n_image_batches)):
    with open("./pickle/images_{}.pickle".format(shard_idx), mode='rb') as f:
        images_list.append(pickle.load(f))

In [None]:
# Gather the shards in submission (= shard index) order. as_completed() yields futures
# in completion order, which could scramble the rows of `images` relative to `data`.
images_list = [future.result() for future in futures]
images = np.concatenate(images_list, axis=0)
del images_list


In [None]:
#images = load_img_data(data)

In [None]:
images.shape[0]  # number of loaded images

In [None]:
np.shape(images[0])  # shape of one flattened image

In [None]:
# Sanity check on the loaded/cached array: display its last image.
test = images[-1]
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 3))
ax.imshow(test.reshape(dim, dim, 3), interpolation='nearest')
ax.axis('off')
plt.show()

In [None]:
batch_size = 128
# Hold out the first two batches for validation and train on everything after them.
n_validation = 2 * batch_size
validation_images = images[:n_validation]
train_images = images[n_validation:]

In [None]:
# Train on CelebA. Reconstructions are written to img/ and checkpoints to model/.
train(
    train_images,
    validation_images,
    batch_size=batch_size,
    image_size=image_size,
    learning_rate=0.001,
)

In [None]:
# Latent codes for the first 2 * batch_size images (the validation slice).
z_list = get_z(images[:2 * batch_size])

In [None]:
z_list.shape[0]  # number of encoded images

In [None]:
from scipy.spatial import distance

# Full n x n matrix of cosine distances between latent codes (zero diagonal).
pairwise = distance.squareform(
    distance.pdist(z_list, metric="cosine")
)

In [None]:
def similar_to(img_id, num=5, distance=False, matrix=None):
    """Return the indices of the `num` entries closest to `img_id` in a distance matrix.

    Args:
        img_id: row index of the query.
        num: how many neighbours to return. The query itself normally comes first,
            because its distance to itself is 0.
        distance: if True, return a list of ``(index, distance)`` pairs instead of
            the index array.
        matrix: square distance matrix to search. Defaults to the global `pairwise`;
            this new, backward-compatible parameter lets the function work on any matrix.
    """
    row = np.asarray(pairwise if matrix is None else matrix)[img_id]
    # Sort once and read the distances back through the same ordering, instead of
    # sorting the row a second time.
    ids = np.argsort(row)[:num]
    if distance:
        return list(zip(ids, row[ids]))
    return ids

def show_sim_image(img_id, num=5):
    """Plot the `num` nearest neighbours of image `img_id` in latent space.

    Neighbours come from `similar_to`; the query itself is normally first. Pixels are
    taken from `images`, the array the latent codes were computed from, so each
    picture matches its `pairwise` index. Re-reading `data[idx]` from disk only
    matched when `images` had been loaded in `data` order.
    """
    id_list = similar_to(img_id, num=num)
    side = int(np.sqrt(image_size / 3))
    # squeeze=False keeps axs 2-D even when num == 1.
    fig, axs = plt.subplots(1, len(id_list), figsize=(20, 3), squeeze=False)
    for ax, idx in zip(axs[0], id_list):
        print(idx)
        ax.imshow(np.reshape(images[idx], (side, side, 3)))
        ax.axis('off')
    plt.show()

In [None]:
# Nearest neighbours (cosine distance in latent space) of the first image.
show_sim_image(img_id=0)

In [None]:
# Preview the first 30 images; the next cell embeds this same slice.
# `test_images` is defined here because, on Restart & Run All, no earlier cell
# creates it (train() only has a local of that name).
test_images = images[:30]
fig, axs = plt.subplots(1, len(test_images), figsize=(20, 3))
for ax, img in zip(axs, test_images):
    ax.imshow(np.reshape(img, (dim, dim, 3)))
    ax.axis('off')
plt.show()

In [None]:
# Scatter latent dims 0 and 1 of the first 30 images, coloured by a hand-written
# one-hot label: f = [0, 1] -> colour class 1, m = [1, 0] -> colour class 0
# (presumably female / male — confirm).
test_images = images[:30]
f = [0, 1]
m = [1, 0]
test_labels = [f, m, f, m, m, f, f, m, m, m, f, f, m, f, f, m, m, m, f, f, m, f, m, f, m, m, m, m, f, f]
assert len(test_labels) == len(test_images), "need exactly one label per test image"

# The context manager closes the session instead of leaking it.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if latest_ckpt:
        saver.restore(sess, latest_ckpt)
    zs = sess.run(z, feed_dict={x: test_images})
print(len(zs))

fig, ax = plt.subplots(1, 1)
ax.scatter(zs[:, 0], zs[:, 1], c=np.argmax(test_labels, 1))
ax.set_xlim([-6, 6])
ax.set_ylim([-6, 6])
# plt.show() instead of fig.show(), which warns on the non-GUI inline backend.
plt.show()


In [None]:
# Closed-form KL between the encoder posteriors of validation images 0 and 200.
# (validation_images must have at least 201 rows.) `sess` stays open on purpose:
# the sampling cell below reuses it together with `p`.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
if latest_ckpt:
    saver.restore(sess, latest_ckpt)

ds = tf.contrib.distributions
# z_mu is already the posterior mean; only the log standard deviation needs exp(),
# matching sigma = exp(z_log_sigma) in the graph's KL term. The old code also
# exponentiated the mean.
mu, sigma = sess.run([z_mu, tf.exp(z_log_sigma)], feed_dict={x: validation_images})
p = ds.Normal(loc=mu[0].tolist(), scale=sigma[0].tolist())
q = ds.Normal(loc=mu[200].tolist(), scale=sigma[200].tolist())
kl = tf.reduce_sum(ds.kl_divergence(p, q))
result = sess.run(kl)


In [None]:
# KL(p || q) between the Normal distributions p and q built in the previous cell,
# summed over latent dimensions.
result

In [None]:
# Same as the previous cell (re-displays `result`); redundant and safe to delete.
result

In [None]:
# Draw 1000 samples from p, then keep latent dimension 2, sorted for plotting.
h = p.sample([1000])
_h = sess.run(h)
__h = sorted(_h[:, 2])

In [None]:
import scipy.stats as stats

# Histogram of the samples of latent dim 2 drawn from p, overlaid with p's own
# density N(mu[0][2], sigma[0][2]) (p was built from mu[0] and sigma[0]).
# The old commented-out fit used mu[1][0] / sigma[1][0], which is a different image
# and a different dimension.
# `density=True` replaces `normed=True`, which was removed from matplotlib.
fig, ax = plt.subplots()
ax.hist(__h, density=True, label="samples")
ax.plot(__h, stats.norm.pdf(__h, mu[0][2], sigma[0][2]), '-', label="p density")
ax.set(title="Samples of z[2] from p", xlabel="z[2]", ylabel="density")
ax.legend()
plt.show()

In [None]:
# mnist

In [None]:
# MNIST via the TF1 tutorial helper (downloads into ./MNIST_DATA), one-hot labels.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_DATA', one_hot=True)

In [None]:
# Unpack the MNIST splits into the names the training helpers use.
train_images, validation_images = mnist.train.images, mnist.validation.images
test_images, test_labels = mnist.test.images, mnist.test.labels

In [None]:
np.shape(train_images[0])  # shape of one training example

In [None]:
# MNIST run. The placeholder `x` is built from the global image_size, so the graph
# cell must first be re-run with image_size = 28 * 28.
assert train_images.shape[1] == image_size, (
    "re-run the graph cell with image_size = 28 * 28 before training on MNIST")
# train() takes no test set (it snapshots validation_images[:10]). The old call passed
# test_images / test_labels positionally into batch_size / image_size, raising
# "TypeError: got multiple values for argument 'image_size'".
train(train_images, validation_images, image_size=28*28)