# RM-MMDnet

In [None]:
%pylab inline
import tensorflow as tf
import numpy as np
from ops import *
from utils import *
import os
import time
from glob import glob
from scipy.misc import imsave as ims
from random import randint
from data_providers import *
slim = tf.contrib.slim
import scipy as sp
import pickle

cifar = True  # dataset flag; not read by any of the visible cells

# Session configuration: expose at most one GPU and let this process claim
# only ~30% of that GPU's memory.
config = tf.ConfigProto(
    device_count={'GPU': 1},
    gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.3),
)

In [None]:
def transform(image, npx=64, is_crop=True):
    """Optionally center-crop an image and rescale its pixels to [-1, 1].

    Args:
        image: array-like image, presumably with values in [0, 255].
        npx: crop size forwarded to ``center_crop`` when cropping.
        is_crop: when True, crop with ``center_crop`` first; otherwise use
            ``image`` unchanged.

    Returns:
        numpy array equal to ``pixels / 127.5 - 1`` (0 -> -1, 255 -> 1).
    """
    source = center_crop(image, npx) if is_crop else image
    scaled = np.array(source) / 127.5
    return scaled - 1.

In [None]:
# Data/shape settings for 32x32 RGB images (CIFAR-10 sized).
params = dict(
    batch_size=64,
    image_dim=32 * 32 * 3,  # flattened image length = h * w * c
    c=3,
    h=32,
    w=32,
)


In [None]:
    
def discriminator(image, reuse=False, i=0):
    """Three-class classifier returning unnormalized logits.

    Applies three conv layers (the second and third followed by batch norm)
    with leaky-ReLU activations, flattens the result and projects it to 3
    logits. The shape notes below follow the original comments (32x32x3 ->
    16x16 -> 8x8 -> 4x4); they assume ``conv2d`` halves the spatial size by
    default — TODO confirm against ``ops``.

    Args:
        image: input batch; reshaped with the module-level ``batchsize`` before
            the final linear layer, so it must contain exactly that many rows.
        reuse: reuse the variables already created under the "disc" scope.
        i: suffix appended to the names of the second/third conv layers and
            the linear layer (the first conv, 'd_h0_conv', has no suffix).

    Returns:
        Tensor of 3 logits per example.
    """
    suffix = str(i)
    with tf.variable_scope("disc") as scope:
        if reuse:
            scope.reuse_variables()
        # Layer order must stay fixed: the un-scoped batch_norm calls get
        # auto-numbered variable scopes (BatchNorm, BatchNorm_1) by call order.
        net = lrelu(conv2d(image, 3, df_dim, name='d_h0_conv'))  # 16x16 x df_dim
        net = lrelu(tf.contrib.layers.batch_norm(
            conv2d(net, df_dim, df_dim * 2, name='d_h1_conv' + suffix)))  # 8x8 x 2*df_dim
        net = lrelu(tf.contrib.layers.batch_norm(
            conv2d(net, df_dim * 2, df_dim * 4, name='d_h2_conv' + suffix)))  # 4x4 x 4*df_dim
        flat = tf.reshape(net, [batchsize, -1])
        return dense(flat, 4 * 4 * df_dim * 4, 3, scope='d_h4_lin' + suffix)
    

        
    

        
def generator(z):
    """Decode a batch of latent vectors into 32x32x3 images in [-1, 1].

    Projects ``z`` to a 4x4 feature map, then upsamples with three transposed
    convolutions (8x8, 16x16, 32x32). Hidden layers use batch norm + ReLU and
    the output uses tanh.

    Channel widths are driven by the module-level ``gf_dim`` hyperparameter,
    in the same way the discriminator reads its widths from ``df_dim``.
    Previously, a local ``gf_dim = 32`` shadowed the global, so editing the
    hyperparameter had no effect.

    Args:
        z: latent tensor with ``z_dim`` features per row. The transposed-conv
            output shapes are built with the module-level ``batchsize``, so
            ``z`` is expected to have exactly ``batchsize`` rows.

    Returns:
        Tensor of shape [batchsize, 32, 32, 3] with tanh-activated values.
    """
    with tf.variable_scope("gen"):
        # Keep this call order: the un-scoped batch_norm calls get
        # auto-numbered variable scopes by creation order.
        z2 = dense(z, z_dim, 4*4*gf_dim*4, scope='g_h0_lin')
        h0 = tf.nn.relu(tf.contrib.layers.batch_norm(
            tf.reshape(z2, [-1, 4, 4, gf_dim*4])))  # 4x4 x 4*gf_dim
        h1 = tf.nn.relu(tf.contrib.layers.batch_norm(
            conv_transpose(h0, [batchsize, 8, 8, gf_dim*2], "g_h1")))  # 8x8 x 2*gf_dim
        h2 = tf.nn.relu(tf.contrib.layers.batch_norm(
            conv_transpose(h1, [batchsize, 16, 16, gf_dim*1], "g_h2")))  # 16x16 x gf_dim
        # Layer name "g_h4" is kept as-is so existing variable names are unchanged.
        h3 = tf.nn.tanh(conv_transpose(h2, [batchsize, 32, 32, 3], "g_h4"))

        return h3
        

In [None]:
# Clear any graph left over from a previous execution of this cell.
tf.reset_default_graph()
sdata = []

# Batch / image geometry
batchsize = batch_size = params['batch_size']
imageshape = [32, 32, 3]
iscrop = False
c_dim = 3

# Latent size and base channel widths of generator / discriminator
z_dim = noise_dim = 128
gf_dim = df_dim = 32

# Adam settings (generator and discriminator share the same learning rate)
learningrate_gen = learningrate = 1e-4
beta1 = 0.5

# Not read by any of the visible cells.
NUM_PROJ = 1
NUM_MMD = 10

# Real-image placeholder; the visible training loop does not feed it
# (real images go through disc_input instead).
images = tf.placeholder(tf.float32, [batchsize] + imageshape, name="real_images")


# Loss names initialised to 0.0; dloss and gloss are later rebound to the
# loss tensors. closs is not used by the visible cells.
dloss=0.0
gloss=0.0
closs=0.0


# Latent noise for the generator. The batch dimension is left open here, but
# the generator's transposed-conv shapes assume exactly batchsize rows.
gen_input = tf.placeholder(tf.float32, shape=[None, noise_dim], name='input_noise')
# Real images for the discriminator.
disc_input = tf.placeholder(tf.float32, shape=[batchsize] + imageshape, name='disc_input')
# Additive perturbations for the real (m) and generated (n) branches. Each one
# gets its own graph name; both previously reused 'disc_input', which TF
# silently renamed to 'disc_input_1' / 'disc_input_2'.
m_input = tf.placeholder(tf.float32, shape=[batchsize] + imageshape, name='m_input')
n_input = tf.placeholder(tf.float32, shape=[batchsize] + imageshape, name='n_input')


# Build Generator Network
gen_sample = generator(gen_input)

# Build 4 applications of ONE shared discriminator. The first call creates the
# "disc" variables and the reuse=True calls share them, so this order matters.
#   disc_real  : real images                         (target class 0, label_a)
#   disc_fake  : generated samples                   (target class 1, label_b)
#   disc_m     : real images + m_input perturbation  (target class 2, label_c)
#   disc_m_gen : generated samples + n_input perturbation (target class 2, label_c)
disc_real = discriminator(disc_input)
disc_fake = discriminator(gen_sample,reuse=True)
disc_m = discriminator(disc_input+m_input,reuse=True)
disc_m_gen = discriminator(gen_sample+n_input,reuse=True)

# Build Loss
# One-hot targets for the 3-way discriminator, one row per batch element:
# class 0 = real images, class 1 = generated samples,
# class 2 = perturbed inputs (disc_m and disc_m_gen).
a = np.tile([1., 0., 0.], batch_size)
b = np.tile([0., 1., 0.], batch_size)
c = np.tile([0., 0., 1.], batch_size)
# np.tile produces float64. Build the label tensors explicitly as float32 so
# they match the float32 logits (all network inputs are float32 placeholders).
label_a = tf.constant(a.reshape(batch_size, 3), dtype=tf.float32)
label_b = tf.constant(b.reshape(batch_size, 3), dtype=tf.float32)
label_c = tf.constant(c.reshape(batch_size, 3), dtype=tf.float32)

disc_loss_1 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=disc_real, labels=label_a))
disc_loss_2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=disc_fake, labels=label_b))
disc_loss_3 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=disc_m, labels=label_c))
disc_loss_4 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=disc_m_gen, labels=label_c))
dloss = disc_loss_1 + disc_loss_2 + disc_loss_3 + disc_loss_4

# Class-probability ratios p(class k) / p(class 2) on perturbed fakes.
# softmax(x)[k] / softmax(x)[2] == exp(x[k] - x[2]) exactly. Computing the ratio
# from the logit difference avoids dividing by a softmax probability that can
# underflow to 0, which would produce inf/NaN losses and gradients.
r_p_m = tf.exp(disc_m_gen[:, 0] - disc_m_gen[:, 2])
r_q_m = tf.exp(disc_m_gen[:, 1] - disc_m_gen[:, 2])

gloss = tf.reduce_mean(tf.square(r_p_m - 1) - tf.square(r_q_m - 1) - 2*(r_q_m - 1)*(r_p_m - r_q_m))
                      
# Split the trainable variables by the scope name each network was built in.
t_vars = tf.trainable_variables()
g_vars = [v for v in t_vars if 'gen' in v.name]
d_vars = [v for v in t_vars if 'disc' in v.name]

# Report parameter counts in the order: all trainable, discriminator, generator.
for group in (t_vars, d_vars, g_vars):
    print(np.sum([np.prod(v.get_shape().as_list()) for v in group]))

# Each optimizer updates only its own network's variables.
g_optim = tf.train.AdamOptimizer(learningrate_gen, beta1=beta1).minimize(gloss, var_list=g_vars)
d_optim = tf.train.AdamOptimizer(learningrate, beta1=beta1).minimize(dloss, var_list=d_vars)


start_time = time.time()
# Fixed latent batch drawn once (presumably meant for comparable sample grids).
display_z = np.random.uniform(-1, 1, [batchsize, z_dim]).astype(np.float32)

batch = CIFAR10DataProvider(batch_size=params['batch_size'])
batch_idxs = batch.num_batches
batch_idx = batch_idxs


sess = tf.InteractiveSession(config=config)

# One initializer is enough. The extra tf.initialize_all_variables().run() was
# a deprecated alias of the same op, and running it again just re-randomized
# every variable.
sess.run(tf.global_variables_initializer())



In [None]:
# '------Training GAN--------------------'
for epoch in range(21):
    batch.new_epoch()

    for idx in range(batch_idxs):
        batch_images, _ = batch.next()

        # Latent noise for the generator, plus small Gaussian perturbations
        # added to the real (m) and generated (n) images for the third class.
        z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])
        m = np.random.normal(0, .01, size=[batchsize] + imageshape)
        n_ = np.random.normal(0, .01, size=[batchsize] + imageshape)

        # Train: update D, then G, in separate run calls. Fetching both train
        # ops in one sess.run imposes no order between them, so G's gradient
        # computation could read D's weights before or after D's in-place
        # Adam update (a nondeterministic race).
        feed_dict = {disc_input: batch_images, gen_input: z, m_input: m, n_input: n_}
        _, dl = sess.run([d_optim, dloss], feed_dict=feed_dict)
        _, gl = sess.run([g_optim, gloss], feed_dict=feed_dict)

    # '---------Printing intermediate results-------------'
    if epoch % 10 == 0:

        print("Epoch: [%2d] [%4d/%4d] time: %4.4f, " % (epoch, idx, batch_idx, time.time() - start_time,))

        # Only the generator output is needed. Fetch it once, from the fixed
        # display_z batch, so grids from different epochs are comparable.
        sdata = sess.run(gen_sample, feed_dict={gen_input: display_z})
        img = merge(sdata[:64], [8, 8])
        img = (img + 1.) / 2.  # generator ends in tanh: map [-1, 1] -> [0, 1]
        plt.imshow(img)
        plt.show()
        
    




In [None]:

# Decode 250 independent batches of uniform latent noise and stack them.
noise_batches = (np.random.uniform(-1, 1, [batchsize, z_dim]).astype(np.float32)
                 for _ in range(250))
gen_images = np.vstack([sess.run(gen_sample, feed_dict={gen_input: noise})
                        for noise in noise_batches])

print(gen_images.shape)

# Rescale from the generator's tanh range [-1, 1] to [0, 1] before saving;
# np.save appends ".npy" to the file name.
np.save("cifar_breg", (gen_images + 1.) / 2.)