In [14]:
from utils import *
import os
import random

In [2]:
IMAGE_DIM = [32, 32, 3]
BATCH_SIZE = 64
IMAGE_SIZE = 108
Z_SIZE = 100
GF_DIM = 32
DF_DIM = 32
GFC_DIM = 1024
DFC_DIM = 1024
C_DIM = 3
LEARNING_RATE = 0.0002
BETA1 = 0.5

In [3]:
# batch norm layers
d_bn1 = batch_norm(name="d_bn1")
d_bn2 = batch_norm(name="d_bn2")
d_bn3 = batch_norm(name="d_bn3")

g_bn0 = batch_norm(name="g_bn0")
g_bn1 = batch_norm(name="g_bn1")
g_bn2 = batch_norm(name="g_bn2")
g_bn3 = batch_norm(name="g_bn3")

In [4]:
def discriminator(image, reuse=None):
    with tf.variable_scope('discriminator') as scope:
        if reuse:
            scope.reuse_variables()

        h0 = leaky_relu(conv_layer(image, DF_DIM, name="d_h0"))
        h1 = leaky_relu(d_bn1(conv_layer(h0, DF_DIM * 2, name="d_h1")))
        h2 = leaky_relu(d_bn2(conv_layer(h1, DF_DIM * 4, name="d_h2")))
        h3 = leaky_relu(d_bn3(conv_layer(h2, DF_DIM * 8, name="d_h3")))
        h4 = dense(tf.reshape(h3, [BATCH_SIZE, -1]), 1, scope="d_h4")

        return tf.nn.sigmoid(h4), h4

In [5]:
def generator(z):
    with tf.variable_scope('generator') as scope:
        z1 = dense(z, outputF=2 * 2 * GF_DIM * 8, scope="g_h0")
        h0 = tf.nn.relu(g_bn0(tf.reshape(z1, [-1, 2, 2,GF_DIM * 8])))
        h1 = tf.nn.relu(g_bn1(conv_transpose_layer(h0, [BATCH_SIZE, 4, 4, GF_DIM * 4], name="g_h1")))
        h2 = tf.nn.relu(g_bn2(conv_transpose_layer(h1, [BATCH_SIZE, 8, 8, GF_DIM * 2], name="g_h2")))
        h3 = tf.nn.relu(g_bn3(conv_transpose_layer(h2, [BATCH_SIZE, 16, 16, GFC_DIM * 1], name="g_h3")))
        h4 = conv_transpose_layer(h3, [BATCH_SIZE, 32, 32, 3])

        return tf.nn.tanh(h4)

In [6]:
image_in = tf.placeholder(tf.float32, [BATCH_SIZE] + IMAGE_DIM, name="image_input")
z_in = tf.placeholder(tf.float32, [None, Z_SIZE], name="z")

G = generator(z_in)
D_real, D_real_logits = discriminator(image_in)
D_model, D_model_logits = discriminator(G, reuse=True)

In [7]:
d_real_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_real_logits, labels=tf.ones_like(D_real))
d_model_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_model_logits, labels=tf.zeros_like(D_model))

d_loss = d_real_loss + d_model_loss

In [8]:
g_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=D_model_logits, labels=tf.ones_like(D_model))

In [9]:
t = tf.trainable_variables()

g_vars = [var for var in t if "g_" in var.name]
d_vars = [var for var in t if "d_" in var.name]

In [10]:
saver = tf.train.Saver()

In [11]:
d_optim = tf.train.AdamOptimizer(LEARNING_RATE, beta1=BETA1).minimize(d_loss, var_list=d_vars)
g_optim = tf.train.AdamOptimizer(LEARNING_RATE, beta1=BETA1).minimize(g_loss, var_list=g_vars)

In [12]:
init  = tf.global_variables_initializer()

sess = tf.InteractiveSession()
sess.run(init)

In [22]:
count = 0

batch = unpickle('data/data_batch_1')
print(batch.keys())

for epoch in range(10):
    batch_idx = 3000
    for i in range(batch_idx):
        batchnum = random.randint(0,150)
        trainingData = batch[b'data'][batchnum*BATCH_SIZE:(batchnum+1)*BATCH_SIZE]
        trainingData = transform(trainingData, is_crop=False)
        batch_images = np.reshape(trainingData,(BATCH_SIZE,3,32,32))
        batch_images = np.swapaxes(batch_images,1,3)
        
        z = np.random.uniform(-1, 1, [BATCH_SIZE, Z_SIZE]).astype(np.float32)
        
        sess.run(d_optim, feed_dict={ image_in: batch_images, z_in: z })
        sess.run(g_optim, feed_dict={ z_in: z })
        
        loss_d = sess.run(d_loss, feed_dict={ image_in: batch_images, z_in: z })
        loss_g = sess.run(g_loss, feed_dict={ z_in: z })
        
        count += 1
        
        print('Epoch: {}, Iteration: {} d_loss: {}, g_loss: {}'.format(epoch, i, loss_d, loss_g))
        
        if count % 100 == 0:
            saver.save(sess, os.getcwd() + "/training", global_step=count)

dict_keys([b'data', b'filenames', b'labels', b'batch_label'])
Epoch: 0, Iteration: 0 d_loss: [[ 1.23776627]
 [ 1.15765929]
 [ 1.30966949]
 [ 1.77457833]
 [ 1.51890016]
 [ 1.48587155]
 [ 1.88778627]
 [ 2.80384827]
 [ 1.58188057]
 [ 1.17355335]
 [ 1.7306993 ]
 [ 2.02813292]
 [ 1.73157263]
 [ 1.71489406]
 [ 1.67143798]
 [ 1.81202984]
 [ 1.39193726]
 [ 1.91472149]
 [ 1.42351151]
 [ 1.52233624]
 [ 1.71970272]
 [ 1.9429245 ]
 [ 1.81131697]
 [ 1.8920387 ]
 [ 1.52665973]
 [ 1.87995315]
 [ 1.64005399]
 [ 1.73177528]
 [ 1.58247733]
 [ 1.94356751]
 [ 1.33981895]
 [ 1.29888284]
 [ 1.45193219]
 [ 1.60130501]
 [ 1.76810265]
 [ 1.48709226]
 [ 1.48702061]
 [ 1.02818441]
 [ 1.55724895]
 [ 1.41838169]
 [ 2.33544016]
 [ 1.56199491]
 [ 1.574512  ]
 [ 2.01990604]
 [ 2.42387056]
 [ 1.49511957]
 [ 1.99328423]
 [ 1.42451024]
 [ 1.24647236]
 [ 1.80301833]
 [ 1.86453795]
 [ 1.48942852]
 [ 1.65260231]
 [ 1.68276954]
 [ 1.85759115]
 [ 1.22307396]
 [ 2.39600897]
 [ 1.23188567]
 [ 1.73287392]
 [ 1.39085746]
 [ 1.36

Epoch: 0, Iteration: 5 d_loss: [[ 0.89905411]
 [ 0.88524616]
 [ 1.22018242]
 [ 1.64742899]
 [ 1.09752858]
 [ 1.05318904]
 [ 0.83346939]
 [ 0.98676968]
 [ 0.88859797]
 [ 0.98614925]
 [ 0.97261161]
 [ 1.42242599]
 [ 1.10698628]
 [ 0.86265063]
 [ 0.83775997]
 [ 0.8333596 ]
 [ 0.94855028]
 [ 1.00691152]
 [ 1.03235865]
 [ 0.77422869]
 [ 1.09914565]
 [ 1.25640965]
 [ 1.07218587]
 [ 1.06833684]
 [ 1.04537487]
 [ 0.90459663]
 [ 1.09662056]
 [ 0.89797616]
 [ 0.80924338]
 [ 0.89657688]
 [ 0.79911315]
 [ 0.82818699]
 [ 0.82603288]
 [ 0.76476312]
 [ 1.36911619]
 [ 0.93024981]
 [ 1.02577233]
 [ 0.7684592 ]
 [ 1.01113546]
 [ 1.43308759]
 [ 1.04616261]
 [ 1.10402226]
 [ 0.82869518]
 [ 0.75332409]
 [ 0.83207059]
 [ 0.80690038]
 [ 0.91613948]
 [ 1.3153863 ]
 [ 0.96501476]
 [ 1.17756259]
 [ 1.05309963]
 [ 1.40038097]
 [ 1.30269766]
 [ 1.01192033]
 [ 0.91112065]
 [ 1.05163765]
 [ 0.83272731]
 [ 1.06447577]
 [ 1.4388206 ]
 [ 0.69716704]
 [ 0.85935616]
 [ 0.85460424]
 [ 1.08003163]
 [ 0.96416283]], g_loss:

Epoch: 0, Iteration: 10 d_loss: [[ 1.24190712]
 [ 0.74844849]
 [ 0.74561942]
 [ 0.33155566]
 [ 0.74969304]
 [ 0.64617932]
 [ 0.64399195]
 [ 0.6836679 ]
 [ 0.71579379]
 [ 0.51795453]
 [ 0.78445637]
 [ 0.5276612 ]
 [ 0.5694021 ]
 [ 0.51464784]
 [ 0.68966776]
 [ 0.53291631]
 [ 0.37450615]
 [ 0.81360161]
 [ 0.37695506]
 [ 0.88960421]
 [ 0.63281035]
 [ 0.64127463]
 [ 0.63693291]
 [ 0.50389111]
 [ 0.63427591]
 [ 0.64571255]
 [ 0.56296533]
 [ 0.89645898]
 [ 0.61417317]
 [ 0.64553583]
 [ 0.58696336]
 [ 0.57745212]
 [ 0.69292027]
 [ 0.50265902]
 [ 0.52373862]
 [ 0.44776264]
 [ 0.52461177]
 [ 0.70488906]
 [ 0.78734899]
 [ 1.00560308]
 [ 0.65581316]
 [ 0.71637291]
 [ 0.38538623]
 [ 0.43711245]
 [ 0.76987576]
 [ 0.42612532]
 [ 0.5845018 ]
 [ 0.55833912]
 [ 0.75584841]
 [ 0.55733335]
 [ 0.89532804]
 [ 0.68659914]
 [ 0.75656348]
 [ 0.74894655]
 [ 0.5598892 ]
 [ 0.4994368 ]
 [ 0.59330881]
 [ 0.65786594]
 [ 0.58295649]
 [ 0.61794674]
 [ 0.42178476]
 [ 0.52262449]
 [ 0.42428026]
 [ 0.58271778]], g_loss

KeyboardInterrupt: 