In [1]:
import numpy as np
import tensorflow as tf
import os,time
from glob import glob

from ops import batch_norm,linear,conv2d,deconv2d,lrelu
from image_helpers import *

In [2]:
#Parameter Defintion
is_crop=True
batch_size=64
image_size=108
sample_size=64
image_shape=[64,64,3]

z_dim=100

gf_dim=64
df_dim=64

learning_rate=0.0002
beta1=0.5

In [3]:
import numpy as np
from PIL import Image
img=Image.open("download.jpg")
img.load()
img = np.asarray( img, dtype="int32" )
img.shape


(102, 102, 3)

In [4]:
d_bn1 = batch_norm(name='d_bn1')
d_bn2 = batch_norm(name='d_bn2')
d_bn3 = batch_norm(name='d_bn3')

g_bn0 = batch_norm(name='g_bn0')
g_bn1 = batch_norm(name='g_bn1')
g_bn2 = batch_norm(name='g_bn2')
g_bn3 = batch_norm(name='g_bn3')

In [5]:
def discriminator(image,reuse=False):
    if reuse:
        tf.get_variable_scope().reuse_variables()
        
    h0=lrelu(conv2d(image,df_dim,name='d_h0_conv'))
    h1=lrelu(d_bn1(conv2d(h0,df_dim*2,name='d_h1_conv')))
    h2=lrelu(d_bn2(conv2d(h1,df_dim*4,name='d_h2_conv')))
    h3=lrelu(d_bn3(conv2d(h2,df_dim*8,name='d_h3_conv')))
    h4=linear(tf.reshape(h3,[batch_size,-1]),1,'d_h3_lin')
             
    return tf.nn.sigmoid(h4),h4

In [6]:
def generator(z):
    z_=linear(z,gf_dim*8*4*4,'g_h0_lin')
    h0=tf.nn.relu(g_bn0(tf.reshape(z_,[-1,4,4,gf_dim*8])))
    h1=tf.nn.relu(g_bn1(deconv2d(h0,[batch_size,8,8,gf_dim*4],name='g_h1')))
    h2=tf.nn.relu(g_bn2(deconv2d(h1,[batch_size,16,16,gf_dim*2],name='g_h2')))
    h3=tf.nn.relu(g_bn3(deconv2d(h2,[batch_size,32,32,gf_dim*1],name='g_h3')))
    h4 = deconv2d(h3, [batch_size, 64, 64, 3], name='g_h4')
    
    return tf.nn.tanh(h4)

In [7]:
images=tf.placeholder(tf.float32,[batch_size]+image_shape,name='real_images')
sample_images=tf.placeholder(tf.float32,[sample_size]+image_shape,name="sample_images")
z=tf.placeholder(tf.float32,[None,z_dim])

G=generator(z)
D,D_logits=discriminator(images)
D_,D_logits_=discriminator(G,reuse=True)

#cost fn
d_loss_real=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logits,tf.ones_like(D)))
d_loss_fake=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logits_,tf.zeros_like(D_)))
d_loss=d_loss_real+d_loss_fake

g_loss=tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logits_,tf.ones_like(D_)))

In [8]:
#Optimizers
t_vars=tf.trainable_variables()

d_vars=[var for var in t_vars if 'd_' in var.name]
g_vars=[var for var in t_vars if 'g_' in var.name]

d_optim=tf.train.AdamOptimizer(learning_rate,beta1=beta1).minimize(d_loss,var_list=d_vars)
g_optim=tf.train.AdamOptimizer(learning_rate,beta1=beta1).minimize(g_loss,var_list=g_vars)


In [9]:
sess=tf.Session()
sess.run(tf.global_variables_initializer())

saver=tf.train.Saver()

In [10]:
#DATASET 
data=glob(os.path.join('img_align_celeba/','*.jpg'))

sample_z=np.random.uniform(-1,1,size=(sample_size,z_dim))
sample_files=data[0:sample_size]
sample=[get_image(sample_file,image_size,is_crop) for sample_file in sample_files]
sample_images=np.reshape(np.array(sample).astype(np.float32),[sample_size]+image_shape)


In [21]:
#Training
counter=1
start_time=time.time()
for epoch in range(1):
    np.random.shuffle(data)
    batchidxs=int(len(data)/batch_size)
    
    for idx in range(batchidxs):
        batch_files=data[idx*batch_size:(idx+1)*batch_size]
        batch=[get_image(batch_file,image_size,is_crop=is_crop) for batch_file in batch_files]
        batch_images=np.reshape(np.array(batch).astype(np.float32),[batch_size]+image_shape)
        
        batch_z=np.random.uniform(-1,1,[batch_size,z_dim]).astype(np.float32)
        
        #Update D network
        sess.run([d_optim],feed_dict={images:batch_images,z:batch_z})
        
        #Update G Network
        sess.run([g_optim],feed_dict={z:batch_z})
        sess.run([g_optim],feed_dict={z:batch_z})
        
        errD_fake = d_loss_fake.eval({z: batch_z}, session=sess)
        errD_real = d_loss_real.eval({images: batch_images}, session=sess)
        errG = g_loss.eval({z: batch_z}, session=sess)
        
        counter+=1
        print('Counter: [%2d] [%4d%4d] time:%4.4f,d_loss:%.8f,g_loss:%.8f' %(counter, idx, batchidxs, time.time() - start_time, errD_fake+errD_real, errG) )
        
        if np.mod(counter,50)==1:
            samples,dl,gl=sess.run([G,d_loss,g_loss],feed_dict={z:sample_z,images:sample_images})
            save_images(samples,'samples\\')
            print('[Sample] d_loss: %.8f, g_loss: %.8f' % (dl, gl))
            
    

Counter: [ 2] [   03165] time:12.4142,d_loss:3.11594152,g_loss:0.47967890
Counter: [ 3] [   13165] time:24.7225,d_loss:2.80728006,g_loss:1.44114816
Counter: [ 4] [   23165] time:37.0987,d_loss:3.27091002,g_loss:0.19224194
Counter: [ 5] [   33165] time:49.3759,d_loss:1.40870380,g_loss:2.17443991
Counter: [ 6] [   43165] time:61.6372,d_loss:2.55670500,g_loss:0.27468506
Counter: [ 7] [   53165] time:73.7274,d_loss:3.36097765,g_loss:0.22207320
Counter: [ 8] [   63165] time:86.0206,d_loss:2.98257113,g_loss:0.74043840
Counter: [ 9] [   73165] time:98.3988,d_loss:3.56873131,g_loss:0.55239958
Counter: [10] [   83165] time:110.7801,d_loss:2.36280870,g_loss:0.46738738
Counter: [11] [   93165] time:122.9243,d_loss:1.87536061,g_loss:0.72975487
Counter: [12] [  103165] time:135.0685,d_loss:1.63017440,g_loss:0.89076966
Counter: [13] [  113165] time:147.4817,d_loss:1.14713836,g_loss:1.18704164
Counter: [14] [  123165] time:159.8360,d_loss:1.60891676,g_loss:0.62380171
Counter: [15] [  133165] time:172

KeyboardInterrupt: 

In [24]:
saver.save(sess,"checkpoint\\all_variables.chk")

'checkpoint\\all_variables.chk'

In [26]:
saver.restore(sess, "checkpoint\\all_variables.chk")
print(sess.run(tf.all_variables()))

Instructions for updating:
Please use tf.global_variables instead.
[array([[ 0.01102006,  0.01705633, -0.00060646, ..., -0.01157028,
        -0.02942483, -0.017461  ],
       [-0.00040834,  0.00715254,  0.00745347, ...,  0.03881093,
         0.03808177, -0.00855326],
       [ 0.0039945 , -0.02106941, -0.02665521, ...,  0.01640008,
        -0.00301098,  0.01563549],
       ..., 
       [ 0.03282744, -0.00428099,  0.0138113 , ...,  0.01004282,
        -0.00538335,  0.00819604],
       [-0.00334191,  0.0120117 , -0.02231237, ...,  0.00409787,
        -0.01726586,  0.02788056],
       [ 0.00291106, -0.00688444,  0.01736336, ..., -0.03953795,
         0.01012639, -0.01378447]], dtype=float32), array([-0.00129302, -0.00932913,  0.00516461, ...,  0.00381078,
        0.0111304 ,  0.0136718 ], dtype=float32), array([ -5.85746625e-03,  -8.46001133e-03,  -7.82437436e-03,
        -9.20925569e-03,   8.42535589e-03,   1.71357859e-02,
        -8.78435001e-03,  -1.21675991e-02,   8.15710635e-04,
     