In [1]:
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

In [2]:
import pickle
mb_size = 16
Z_dim = 100
x, t = [], []
i = 0
for i in range(5):
    path = "data_batch_" + str(i + 1)
    with open(path, 'rb') as f:
        batch = pickle.load(f,encoding='latin1')
    x.append(batch['data'])
    t.append(batch['labels'])
T = np.zeros((50000,10),dtype = np.int32)
x = np.concatenate(x)/np.float32(255)
t = np.concatenate(t).astype(np.int32)
for i in range(50000):
    a = t[i]
    T[i,a] = 1
h_dim = 128
print(x)
print(T)
print(x.shape)
print(T.shape)

[[0.23137255 0.16862746 0.19607843 ... 0.54901963 0.32941177 0.28235295]
 [0.6039216  0.49411765 0.4117647  ... 0.54509807 0.5568628  0.5647059 ]
 [1.         0.99215686 0.99215686 ... 0.3254902  0.3254902  0.32941177]
 ...
 [0.13725491 0.15686275 0.16470589 ... 0.3019608  0.25882354 0.19607843]
 [0.7411765  0.7294118  0.7254902  ... 0.6627451  0.67058825 0.67058825]
 [0.8980392  0.9254902  0.91764706 ... 0.6784314  0.63529414 0.6313726 ]]
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 1]
 [0 0 0 ... 0 0 1]
 ...
 [0 0 0 ... 0 0 1]
 [0 1 0 ... 0 0 0]
 [0 1 0 ... 0 0 0]]
(50000, 3072)
(50000, 10)


In [3]:
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)

In [4]:
X_dim = x.shape[1]
y_dim = T.shape[1]


In [5]:
X = tf.placeholder(tf.float32, shape=[None, 3072])

y = tf.placeholder(tf.float32, shape=[None, y_dim])

D_W1 = tf.Variable(xavier_init([X_dim + y_dim, h_dim]))
D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))

D_W2 = tf.Variable(xavier_init([h_dim, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

theta_D = [D_W1, D_W2, D_b1, D_b2]

In [6]:
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])

G_W1 = tf.Variable(xavier_init([Z_dim + y_dim, h_dim]))
G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))

G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))

theta_G = [G_W1, G_W2, G_b1, G_b2]

In [7]:
def discriminator(x1, y1):
    inputs = tf.concat(axis=1, values=[x1, y1])
    D_h1 = tf.nn.leaky_relu(tf.matmul(inputs, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.tanh(D_logit)

    return D_prob, D_logit

In [8]:
def generator(z, y):
    inputs = tf.concat(axis=1, values=[z, y])
    G_h1 = tf.nn.relu(tf.matmul(inputs, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)

    return G_prob

In [9]:
def sample_Z(m, n):
    return np.random.uniform(-1., 1., size=[m, n])

In [11]:
def plot(samples):
    fig = plt.figure(figsize=(4,4 ))
    gs = gridspec.GridSpec(4,4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        R = sample[0:1024].reshape(32,32)
        G = sample[1024:2048].reshape(32,32)
        B = sample[2048:].reshape(32,32)

        img = np.dstack((R,G,B)) 
        plt.imshow(img,  cmap='Greys_r')

    return fig

In [12]:
G_sample = generator(Z, y)
D_real, D_logit_real = discriminator(X, y)
D_fake, D_logit_fake = discriminator(G_sample, y)

D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

D_solver = tf.train.AdamOptimizer(learning_rate = 1e-3).minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer(learning_rate = 1e-4).minimize(G_loss, var_list=theta_G)


In [13]:
if not os.path.exists('outcc1/'):
    os.makedirs('outcc1/')


In [None]:


for it in range(1000000):
    if it % 1000 == 0:
        for u in range(10):
            n_sample = 16

            Z_sample = sample_Z(16, Z_dim)
            y_sample = np.zeros(shape=[16, y_dim])
            y_sample[:, u] = 1

            samples = sess.run(G_sample, feed_dict={Z: Z_sample, y:y_sample})

            fig = plot(samples)
            plt.savefig('outcc1/{}.png'.format("a" + str(u) + str(i).zfill(3)), bbox_inches='tight')
            i += 1
            plt.close(fig)

    k = 0
    while k < 50000:
        start = k
        end = k + mb_size
        X_mb = x[start:end,:]                
        y_mb = T[start:end,:]                
        k += mb_size
    Z_sample = sample_Z(mb_size, Z_dim)
    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: Z_sample, y:y_mb})
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: Z_sample, y:y_mb})

    if it % 1000 == 0:
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'. format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
        print()

Iter: 0
D loss: 1.53
G_loss: 17.49

Iter: 1000
D loss: 0.1204
G_loss: 6.887

Iter: 2000
D loss: 0.2234
G_loss: 6.228

Iter: 3000
D loss: 0.1697
G_loss: 11.81

Iter: 4000
D loss: 0.3366
G_loss: 7.202

Iter: 5000
D loss: 0.1541
G_loss: 9.67

Iter: 6000
D loss: 0.1422
G_loss: 7.404

Iter: 7000
D loss: 0.9223
G_loss: 5.211

Iter: 8000
D loss: 0.3593
G_loss: 7.574

Iter: 9000
D loss: 0.3281
G_loss: 9.683

Iter: 10000
D loss: 0.1479
G_loss: 4.871

Iter: 11000
D loss: 0.2228
G_loss: 4.928

Iter: 12000
D loss: 0.1993
G_loss: 6.965

Iter: 13000
D loss: 0.2689
G_loss: 14.04

Iter: 14000
D loss: 0.4707
G_loss: 9.106

Iter: 15000
D loss: 0.03137
G_loss: 13.84

Iter: 16000
D loss: 0.211
G_loss: 8.142

Iter: 17000
D loss: 0.4862
G_loss: 12.9

Iter: 18000
D loss: 0.19
G_loss: 5.171

Iter: 19000
D loss: 0.2865
G_loss: 20.09

Iter: 20000
D loss: 0.2473
G_loss: 5.096

Iter: 21000
D loss: 0.06536
G_loss: 9.149

Iter: 22000
D loss: 0.2499
G_loss: 6.565

Iter: 23000
D loss: 0.4512
G_loss: 10.92

Iter: 2400

In [79]:
saver=tf.train.Saver()
save_path = saver.save(sess, "./model1.ckpt")

In [18]:
for i in range(10):
    n_sample = 16

    Z_sample = sample_Z(16, Z_dim)
    y_sample = np.zeros(shape=[16, y_dim])

    y_sample[:, i] = 1

    samples = sess.run(G_sample, feed_dict={Z: Z_sample, y:y_sample})

    fig = plot(samples)
    name = 'a' + str(i) + '.png'
    plt.savefig(name, bbox_inches='tight')
    plt.close(fig)