In [1]:
import numpy as np
import tensorflow as tf

In [2]:
# Hyperparameters shared by the graph definition and the training loop.
D_data = 28 * 28   # length of one flattened 28x28 image (784)
D_noise = 10       # length of the noise vector fed to the generator
batch_size = 50    # examples per training step

In [3]:
# Graph inputs; the leading dimension is left open (None) so the batch
# size is decided by whatever is fed at run time.
image_input = tf.placeholder(dtype=tf.float32, shape=[None, D_data])    # flattened real images
noise_input = tf.placeholder(dtype=tf.float32, shape=[None, D_noise])   # generator noise vectors

In [4]:
from Classes.LayerClass import FC, BN
from Classes.ModuleClass import FF, LinearModule

# Generator: maps a D_noise-dimensional noise vector to a flattened image
# of D_data values through two hidden layers (100 and 50 units).
layers = [FC([D_noise,100]),BN(),tf.nn.relu,
          FC([100,50]),BN(),tf.nn.relu,
          FC([50,D_data]),tf.nn.relu]
gen_func = LinearModule(layers)
print(gen_func)

# Discriminator: maps a flattened image to a single sigmoid output in
# (0, 1), trained towards 1 for real data and 0 for generated images.
layers = [FC([D_data,100]),BN(),tf.nn.relu,
          FC([100,50]),BN(),tf.nn.relu,
          FC([50,1]),tf.nn.sigmoid]
dis_func = LinearModule(layers)
print(dis_func)

gen_im = gen_func(noise_input)
# Reshape with an inferred (-1) batch dimension: the placeholders accept
# any batch size, so hard-coding `batch_size` here would make the summary
# op fail whenever a different number of noise vectors is fed.
tf.image_summary('generatedims',tf.reshape(gen_im,[-1,28,28,1]),max_images=30)
# The same discriminator module is applied to generated and real images.
gen_out = dis_func(gen_im)
data_out = dis_func(image_input)


Trans0 => Shape: [10, 100]   Mapping: f(x)
BN_layer1 =>    Mapping: f(x)
<function relu at 0x108fd6e18>   Mapping: f(x)
Trans2 => Shape: [100, 50]   Mapping: f(x)
BN_layer3 =>    Mapping: f(x)
<function relu at 0x108fd6e18>   Mapping: f(x)
Trans4 => Shape: [50, 784]   Mapping: f(x)
<function relu at 0x108fd6e18>   Mapping: f(x)


Trans5 => Shape: [784, 100]   Mapping: f(x)
BN_layer6 =>    Mapping: f(x)
<function relu at 0x108fd6e18>   Mapping: f(x)
Trans7 => Shape: [100, 50]   Mapping: f(x)
BN_layer8 =>    Mapping: f(x)
<function relu at 0x108fd6e18>   Mapping: f(x)
Trans9 => Shape: [50, 1]   Mapping: f(x)
<function sigmoid at 0x108f77ae8>   Mapping: f(x)



In [5]:
def loss(x, eps=1e-8):
    """Mean over the batch of the per-example summed log of `x`.

    Args:
        x: 2-D tensor of positive values (here: discriminator
            probabilities); axis 1 is summed, then the result is averaged.
        eps: lower bound applied to `x` before the logarithm so that a
            saturated value of exactly 0 yields a large finite loss
            instead of -inf / NaN gradients.

    Returns:
        Scalar tensor holding mean(sum(log(max(x, eps)), axis=1)).
    """
    bounded = tf.maximum(x, eps)
    return tf.reduce_mean(tf.reduce_sum(tf.log(bounded),reduction_indices=[1]))

def dis_loss_fn(gen_out,data_out):
    """Discriminator objective built from its two outputs.

    Minimising it pushes `gen_out` (score on generated images) towards
    zero and `data_out` (score on real images) towards one.
    """
    with tf.name_scope('disloss'):
        fake_term = -loss(1-gen_out)   # penalise high scores on generated images
        real_term = loss(data_out)     # reward high scores on real images
        return fake_term - real_term

def gen_loss_fn(gen_out,data_out):
    """Generator objective: minimising it pushes `gen_out` towards one.

    `data_out` is accepted for signature symmetry with `dis_loss_fn`
    but is not used.
    """
    with tf.name_scope('genloss'):
        fooled = loss(gen_out)   # log-score the discriminator gives generated images
        return -fooled

In [6]:
# Split the graph variables between the two networks by position.
# NOTE(review): this assumes the generator's variables were all created
# before the discriminator's and that both networks own the same number
# of variables (they are built from mirrored layer lists) - confirm
# against the FC/BN implementations, or switch to scope-based lookup.
variables = tf.get_collection(tf.GraphKeys.VARIABLES)
index = len(variables) // 2

gen_loss = gen_loss_fn(gen_out,data_out)
tf.scalar_summary('genloss',gen_loss)
# First half belongs to the generator. The previous slice stopped at
# index-1, which silently excluded the generator's last variable from
# its optimizer's var_list.
gen_vars = variables[:index]

dis_loss = dis_loss_fn(gen_out,data_out)
tf.scalar_summary('disloss',dis_loss)
# Second half belongs to the discriminator.
dis_vars = variables[index:]


def acc(x,t):
    """Fraction of predictions in `x` that round to the target value `t`.

    Args:
        x: tensor of predicted probabilities.
        t: target label (0.0 or 1.0) compared against every prediction.

    Returns:
        Scalar float32 tensor in [0, 1].
    """
    # Round to a hard 0/1 decision first: comparing raw sigmoid outputs
    # to 0.0 / 1.0 for exact equality is almost never true.
    decisions = tf.round(x)
    return tf.reduce_mean(tf.cast(tf.equal(decisions, t), tf.float32))
# Discriminator accuracy: generated images should be scored 0, real ones 1.
# The two per-source accuracies are averaged so the reported value stays
# in [0, 1]; adding them, as before, produced a metric ranging up to 2.
accuracy = 0.5 * (acc(gen_out,0.0) + acc(data_out,1.0))
tf.scalar_summary('accuracy', accuracy)

<tf.Tensor 'ScalarSummary_2:0' shape=() dtype=string>

In [7]:
# Each network gets its own momentum optimizer and only updates its own
# variables, so a generator step never modifies the discriminator and
# vice versa. (RMSProp and Adam were tried here as alternatives.)
gen_opt = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.5)
gen_step = gen_opt.minimize(gen_loss, var_list=gen_vars)

dis_opt = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.5)
dis_step = dis_opt.minimize(dis_loss, var_list=dis_vars)

In [8]:
def feeder(ims,labels):
    """Build the feed dict for one training step.

    Args:
        ims: array of raw images whose total size is a multiple of
            D_data; pixel values are divided by 255.
        labels: unused; accepted so the dataset's (images, labels) pair
            can be passed straight through.

    Returns:
        Dict mapping `image_input` to the flattened, rescaled images and
        `noise_input` to standard-normal noise with one row per image.
    """
    # Infer the batch length from the data rather than from `batch_size`
    # so a batch of any size is accepted.
    x = ims.reshape(-1,D_data)/255.0
    z = np.random.standard_normal((x.shape[0],D_noise))
    return {image_input:x,noise_input:z}

def fetcher(sess,feed):
    """Evaluate the current losses and generated images, and print progress.

    Args:
        sess: session used to run the graph.
        feed: feed dict for the placeholders (see `feeder`).

    Returns:
        The evaluated generated images for this feed, i.e. the value
        produced by running `gen_im`, not the graph tensor itself.
    """
    gL,dL,gim = sess.run([gen_loss,dis_loss,gen_im],feed)
    # '\r' with end='' rewrites a single status line instead of scrolling.
    print('\r Gen Loss: {:.7f} Dis Loss: {:.7f}'.format(gL,dL),end='')
    return gim

In [None]:
# Start every run from an empty TensorBoard log directory.
log_dir = 'Tensorboard/'
if tf.gfile.Exists(log_dir):
    tf.gfile.DeleteRecursively(log_dir)
tf.gfile.MakeDirs(log_dir)

from fuel.datasets import MNIST
from fuel.schemes import ShuffledScheme

# Open the MNIST training split and build a scheme that yields shuffled
# index batches of `batch_size` examples over the whole split.
mnist = MNIST(('train',))
state = mnist.open()
scheme = ShuffledScheme(examples=int(mnist.num_examples), batch_size=batch_size)

In [None]:
merged = tf.merge_all_summaries()

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    writer = tf.train.SummaryWriter(log_dir,sess.graph)

    # try/finally guarantees the event file is flushed and closed even if
    # training is interrupted or raises part-way through.
    try:
        count = 0   # step index used for the written summaries
        for e in range(100):
            for i,request in enumerate(scheme.get_request_iterator()):
                #need to generalise this to different data types?
                ims,labels = mnist.get_data(state=state, request=request)

                feed = feeder(ims,labels)
                # Prints the current losses and returns the generated images.
                generated_ims = fetcher(sess,feed)

                # The discriminator is updated on every batch, the
                # generator only on every 50th.
                sess.run(dis_step,feed)
                if i%50 == 0:
                    sess.run(gen_step,feed)

                if i%1000 == 0:
                    #Summaries
                    summary = sess.run(merged,feed_dict=feed)
                    writer.add_summary(summary, count)
                    count+=1
    finally:
        writer.close()

 Gen Loss: 2.3003037 Dis Loss: 0.1082626