In [None]:
import tensorflow as tf
slim = tf.contrib.slim
import numpy as np
from ops import *
from utils import *
import os
import time
from glob import glob
from scipy.misc import imsave as ims
from data_providers import *
import scipy as sp
from keras import metrics
from keras import backend as K
%pylab inline
import datetime
import sys
import json
import gc

In [None]:
# Hyper-parameters shared by the model-definition, graph-building and
# training cells. All values are plain ints.
params = dict(
    batch_size=32,             # rows per placeholder / per training step
    image_dim=128 * 128 * 3,   # NOTE(review): not referenced in the cells shown here and
                               # does not match the 64x64x3 images used below -- confirm
    c=3,
    h=64,                      # also read as the base conv width (df_dim) by the model cells
    w=64,
    im_channels=3,
    im_height=64,
    im_width=64,
    latent_code=11,            # width of the conditioning code fed to the encoder
    latent_noise_dim=100,      # width of the noise vector `eps`
    color_dim=50,
)


## Model definition

In [None]:
def encoder(x, eps, reuse=False):
    """Build the "enc" sub-graph: latent code -> image, modulated by noise.

    The code is projected with a dense layer, reshaped to a small feature map
    and upsampled through four `conv_transpose` layers. After every layer but
    the last, the features are instance-normalised (no learned centre/scale)
    and then re-scaled and shifted by maps obtained from fully connected
    projections of `eps`.

    Args:
        x: code tensor; passed to `dense` with a declared input width of
            `params['latent_code']`.
        eps: noise tensor fed to every `slim.fully_connected` projection.
        reuse: when true, variables already created in scope "enc" are reused.

    Returns:
        Sigmoid of the final `conv_transpose`, whose requested output shape is
        `[batchsize, 64, 64, 3]`.

    Note:
        Reads the module-level names `batchsize` and `params` at call time.
    """
    with tf.variable_scope("enc") as scope:
        if reuse:
            scope.reuse_variables()

        gf_dim = 64
        out_h, out_w = 64, 64
        s2, s4, s8, s16 = [int(gf_dim / div) for div in (2, 4, 8, 16)]

        def noise_affine(shape):
            # Two separate FC projections of eps (shift first, gain second, so
            # the auto-generated slim scope names keep their order), each
            # reshaped to `shape`.
            n_units = shape[1] * shape[2] * shape[3]
            shift = tf.reshape(slim.fully_connected(eps, n_units), shape)
            gain = tf.reshape(slim.fully_connected(eps, n_units), shape)
            return shift, gain

        def modulate(feat, shift, gain):
            # Parameter-free instance norm followed by the noise-driven affine map.
            normed = tf.contrib.layers.instance_norm(feat, center=False, scale=False)
            return tf.nn.relu(gain * normed + shift)

        # Dense projection of the code, viewed as an s16 x s16 feature map.
        x2 = dense(x, params['latent_code'], gf_dim * 8 * s16 * s16, scope='g_h0_lin')
        base_shape = [-1, s16, s16, gf_dim * 8]
        shift, gain = noise_affine(base_shape)
        net = modulate(tf.reshape(x2, base_shape), shift, gain)

        # Three noise-modulated transposed convolutions.
        for size, depth, layer_name in ((s8, gf_dim * 4, "g_h1"),
                                        (s4, gf_dim * 2, "g_h2"),
                                        (s2, gf_dim * 1, "g_h3")):
            shift, gain = noise_affine([batchsize, size, size, depth])
            upsampled = conv_transpose(net, [batchsize, size, size, depth], layer_name)
            net = modulate(upsampled, shift, gain)

        # Final layer: no normalisation, sigmoid output.
        return tf.nn.sigmoid(conv_transpose(net, [batchsize, out_h, out_w, 3], "g_h4"))


In [None]:
def discriminator(image, c, reuse=False):
    """Build the "disc" sub-graph and return one logit per example.

    Args:
        image: image tensor to be judged.
        c: conditioning tensor concatenated to `image` along axis 3.
        reuse: when true, variables already created in scope "disc" are reused.

    Returns:
        The output of the final `dense` layer (declared as 8192 -> 1).

    Note:
        Reads the module-level names `batchsize` and `params` at call time.
    """
    with tf.variable_scope("disc") as scope:
        if reuse:
            scope.reuse_variables()

        df_dim = params['h']

        # Stack both inputs channel-wise; the first conv is declared with 6
        # input channels, i.e. it presumably expects two 3-channel tensors.
        net = tf.concat([image, c], 3)
        net = lrelu(conv2d(net, 6, df_dim, name='d_h0_conv'))

        # Three batch-normalised convs, each doubling the channel count.
        for layer_name, mult in (('d_h1_conv', 1),
                                 ('d_h2_conv', 2),
                                 ('d_h3_conv', 4)):
            conv = conv2d(net, df_dim * mult, df_dim * mult * 2, name=layer_name)
            net = lrelu(tf.contrib.layers.batch_norm(conv))

        # Flatten per example. The dense layer is declared with 8192 inputs --
        # assumes the conv stack leaves 4x4x512 features; confirm against conv2d.
        flat = tf.reshape(net, [batchsize, -1])
        return dense(flat, 8192, 1, scope='d_mu_lin')
    

        

def decoder(y, eps, reuse=False):
    """Build the "dec" sub-graph: image -> conv bottleneck -> image.

    The input is passed through four convolutions, flattened, projected with a
    dense layer and upsampled again through four `conv_transpose` layers. All
    layers except the first conv and the last transposed conv are
    instance-normalised (no learned centre/scale) and then re-scaled and
    shifted by maps obtained from fully connected projections of `eps`.

    Args:
        y: input image tensor; the first conv is declared with 3 input channels.
        eps: noise tensor fed to every `slim.fully_connected` projection.
        reuse: when true, variables already created in scope "dec" are reused.

    Returns:
        Sigmoid of the final `conv_transpose`, whose requested output shape is
        `[batchsize, 64, 64, 3]`.

    Note:
        Reads the module-level names `batchsize` and `params` at call time.
    """
    with tf.variable_scope("dec") as scope:
        if reuse:
            scope.reuse_variables()

        gf_dim = 64
        df_dim = params['h']
        out_h, out_w = 64, 64
        s2, s4, s8, s16 = [int(gf_dim / div) for div in (2, 4, 8, 16)]

        def noise_affine(shape):
            # Two separate FC projections of eps (shift first, gain second, so
            # the auto-generated slim scope names keep their order), each
            # reshaped to `shape`.
            n_units = shape[1] * shape[2] * shape[3]
            shift = tf.reshape(slim.fully_connected(eps, n_units), shape)
            gain = tf.reshape(slim.fully_connected(eps, n_units), shape)
            return shift, gain

        def modulate(feat, shift, gain, activation):
            # Parameter-free instance norm, noise-driven affine map, activation.
            normed = tf.contrib.layers.instance_norm(feat, center=False, scale=False)
            return activation(gain * normed + shift)

        # Conv
        net = lrelu(conv2d(y, 3, df_dim, name='d_h0_conv'))
        for layer_name, mult, size in (('d_h1_conv', 1, s4),
                                       ('d_h2_conv', 2, s8),
                                       ('d_h3_conv', 4, s16)):
            shift, gain = noise_affine([-1, size, size, gf_dim * mult * 2])
            conv = conv2d(net, df_dim * mult, df_dim * mult * 2, name=layer_name)
            net = modulate(conv, shift, gain, lrelu)

        # Flatten per example and project; the dense layer is declared 8192 -> 8192.
        flat = tf.reshape(net, [batchsize, -1])
        y2 = dense(flat, 8192, gf_dim * 8 * s16 * s16, scope='g_h0_lin')

        # Deconv
        base_shape = [-1, s16, s16, gf_dim * 8]
        shift, gain = noise_affine(base_shape)
        net = modulate(tf.reshape(y2, base_shape), shift, gain, tf.nn.relu)

        for size, depth, layer_name in ((s8, gf_dim * 4, "g_h1"),
                                        (s4, gf_dim * 2, "g_h2"),
                                        (s2, gf_dim * 1, "g_h3")):
            shift, gain = noise_affine([batchsize, size, size, depth])
            upsampled = conv_transpose(net, [batchsize, size, size, depth], layer_name)
            net = modulate(upsampled, shift, gain, tf.nn.relu)

        # Final layer: no normalisation, sigmoid output.
        return tf.nn.sigmoid(conv_transpose(net, [batchsize, out_h, out_w, 3], "g_h4"))
        

In [None]:
# Close the session left over from a previous run of this cell *before* the
# default graph is reset: resetting the graph while a session is still active
# is documented as undefined behaviour. `sess.close()` (rather than the
# unbound `tf.Session.close`) lets InteractiveSession undo its own
# default-session registration as well.
try:
    sess.close()
except NameError:
    # First run of the cell: `sess` does not exist yet.
    print("nothing to close")

tf.reset_default_graph()
batchsize = params['batch_size']
imageshape = [64, 64, 3]
c_dim = params['latent_code']
gf_dim = 64
df_dim = 64
learningrate = 0.0001
beta1 = 0.5


# Placeholders all use a fixed batch dimension; the model functions reshape
# with the same `batchsize`, so every feed must contain exactly that many rows.
images = tf.placeholder(tf.float32, [batchsize] + imageshape, name="real_images")
c_input = tf.placeholder(tf.float32, [batchsize] + [c_dim], name="code")
eps = tf.placeholder(tf.float32, [batchsize, params['latent_noise_dim']], name="noise")

#Encoder: code + noise -> image
E_enc = encoder(c_input, eps)

#Decoder: encoder image + noise -> image
G_dec = decoder(E_enc, eps)

#Discriminator: both calls are conditioned on the encoder output and share
#the variables in scope "disc".
img_logits = discriminator(images, E_enc)
gen_logits = discriminator(G_dec, E_enc, reuse=True)

#Disc Loss: real images labelled 1, decoder output labelled 0
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=img_logits, labels=tf.ones_like(img_logits)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=gen_logits, labels=tf.zeros_like(gen_logits)))
d_loss = (d_loss_real + d_loss_fake)/2.

#Gen Loss: decoder output should be classified as real
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=gen_logits, labels=tf.ones_like(gen_logits)))

#Enc Loss: mean squared error between the real images and the encoder output
e_loss = tf.reduce_mean(metrics.mean_squared_error(images, E_enc))

#Optimisation: one Adam optimiser per sub-network. Variables are assigned to
#a sub-network by substring match on the scope names "dec" / "disc" / "enc".
t_vars = tf.trainable_variables()
g_vars = [var for var in t_vars if 'dec' in var.name]
d_vars = [var for var in t_vars if 'disc' in var.name]
e_vars = [var for var in t_vars if 'enc' in var.name]

optim_e = tf.train.AdamOptimizer(learningrate, beta1=beta1).minimize(e_loss, var_list=e_vars)
optim_g = tf.train.AdamOptimizer(learningrate, beta1=beta1).minimize(g_loss, var_list=g_vars)
optim_d = tf.train.AdamOptimizer(learningrate, beta1=beta1).minimize(d_loss, var_list=d_vars)

start_time = time.time()

sess = tf.InteractiveSession()
# A single initialiser run covers every variable, including the Adam slots;
# the deprecated tf.initialize_all_variables() did the same work a second time.
sess.run(tf.global_variables_initializer())

## Read data

In [None]:
def unison_shuffled_copies(a, b):
    """Return shuffled copies of `a` and `b` that share one random permutation.

    Row i of the first result and row i of the second result come from the
    same original index, so paired arrays stay aligned. The inputs are not
    modified.

    Args:
        a: array supporting integer-array indexing (e.g. a numpy array).
        b: array of the same length as `a`.

    Returns:
        Tuple `(a[p], b[p])` for a random permutation `p` of `range(len(a))`.

    Raises:
        AssertionError: if `a` and `b` differ in length.
    """
    assert len(a) == len(b), "inputs must have the same length"
    # Use the `np` alias imported at the top of the notebook; the bare `numpy`
    # name only exists as a side effect of `%pylab inline`.
    p = np.random.permutation(len(a))
    return a[p], b[p]

# Upper bound on the number of training examples kept in memory.
TRAIN_STOP = 100000

# Read the image store inside a `with` block so the file handle is closed as
# soon as the JSON has been parsed.
with open('../img_store') as img_store_file:
    data = json.load(img_store_file)[:TRAIN_STOP]
c_store = np.load('../c_store_full')[:TRAIN_STOP]
gc.collect()

# Both stores are already truncated to TRAIN_STOP above, so no second slice
# is needed here.
train_data64 = np.asarray(data)
train_c_store = c_store

# The two arrays are shuffled and batched in lock-step later on; fail early
# if they are not aligned.
assert len(train_data64) == len(train_c_store), "image store and code store differ in length"

gc.collect()

## Training

In [None]:
from tqdm import *

NUM_EPOCHS = 15
# Epochs at the start of training that update only the encoder (optim_e).
# 0 keeps the original behaviour, where the `epoch < 0` test never held and
# every step ran the generator and discriminator updates.
ENCODER_PRETRAIN_EPOCHS = 0
PREVIEW_EVERY = 500   # show sample images every this many batches
PREVIEW_COUNT = 5     # images per preview row


def _show_image_row(batch, count=PREVIEW_COUNT):
    """Plot the first `count` images of `batch` in one row, clipped to [0, 1]."""
    row = np.clip(batch[:count], 0, 1)
    plt.figure(figsize=(8, 8))
    plt.imshow(merge(row, [1, count]))
    plt.show()


batch_size = params['batch_size']

# `range` works on both Python 2 and 3 and matches the inner loop.
for epoch in tqdm_notebook(range(NUM_EPOCHS)):

    # Reshuffle file names and codes together so pairs stay aligned.
    train_data64, train_c_store = unison_shuffled_copies(train_data64, train_c_store)
    gc.collect()

    print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    # Integer division: a trailing partial batch is dropped, because the
    # placeholders have a fixed batch dimension.
    batch_idxs = len(train_data64) // batch_size

    # NOTE: `idx`, `c_batch` and `batch_img` keep their names on purpose --
    # the traversal cell below reads them after the loop has finished.
    for idx in tqdm_notebook(range(batch_idxs)):

        batch_slice = slice(idx * batch_size, (idx + 1) * batch_size)
        batch_files = train_data64[batch_slice]
        c_batch = np.asarray(train_c_store[batch_slice])

        eps_batch = np.random.normal(0, 1, [batch_size, params['latent_noise_dim']]).astype(np.float32)

        # Load and resize the images, scaling pixel values by 1/255.
        batch_img = np.array([
            get_image(batch_file,
                      input_height=178,
                      input_width=218,
                      resize_height=64,
                      resize_width=64,
                      is_crop=False) / 255.
            for batch_file in batch_files
        ]).astype(np.float32)

        # Train
        feed_dict = {c_input: c_batch, images: batch_img, eps: eps_batch}
        if epoch < ENCODER_PRETRAIN_EPOCHS:
            sess.run([optim_e], feed_dict=feed_dict)
        else:
            sess.run([optim_g, optim_d], feed_dict=feed_dict)

        if idx % PREVIEW_EVERY == 0:

            print("Epoch: [%2d] [%4d/%4d] time: %4.4f, " % (epoch, idx, batch_idxs, time.time() - start_time,))

            E_enc_data, G_dec_data = sess.run([E_enc, G_dec], feed_dict=feed_dict)
            # Encoder output, decoder output, then the real images.
            for preview in (E_enc_data, G_dec_data, batch_img):
                _show_image_row(preview)

## Get traversal gif

In [None]:
import imageio

batch_size = params['batch_size']
latent_dim = 2   # index of the code dimension that is swept

# Codes of the last batch visited by the training loop (`idx` is left over
# from that loop); row 2 of it is the base code for the traversal.
c_batch = np.asarray(train_c_store[idx*batch_size:(idx+1)*batch_size])

# Sweep the chosen dimension across the range observed in the training codes.
lb = np.min(train_c_store[:, latent_dim])
ub = np.max(train_c_store[:, latent_dim])

# One copy of the base code per batch row. Cast to float32 (the placeholder
# dtype) so the linspace values are not truncated if the stored codes are ints.
c_batch_ = np.repeat(c_batch[2:3], batch_size, axis=0).astype(np.float32)
c_batch_[:, latent_dim] = np.linspace(lb, ub, batch_size)

# A single noise vector shared by all rows, so only the code varies.
eps_test = np.random.normal(0, 1, [1, params['latent_noise_dim']])
latent_noise = eps_test.repeat(batch_size, axis=0)
images_ = []

# G_dec depends only on the code and noise placeholders, so no real images
# need to be fed here.
feed_dict = {c_input: c_batch_, eps: latent_noise}
G_dec_data = sess.run(G_dec, feed_dict=feed_dict)

#Visualise: one frame per swept value
for val in range(batch_size):
    frame = np.clip(G_dec_data[val:val+1], 0, 1)
    img = merge(frame, [1, 1])
    images_.append(img)
    plt.figure(figsize=(1, 1))
    plt.imshow(img)
    plt.show()
imageio.mimsave('traversal.gif', images_)