In [None]:
import sys
from matplotlib import pyplot as plt
sys.path.append('../')
from dataset import *
from abstract_network import *
import time
from models import *
import os
# Expose only GPU 1 to TensorFlow.
# NOTE(review): this runs after the project modules above are imported; the setting
# presumably only takes effect if CUDA has not been initialised yet — consider moving
# it above those imports.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})

In [None]:
# Build the graph: a classifier predicts label vectors c_ from images x, and a
# conditional discriminator scores (image, label) pairs for fed labels c vs. predictions c_.
dataset = MnistDataset(binary=False, one_hot=True)
x = tf.placeholder(tf.float32, [None] + dataset.data_dims)
c = tf.placeholder(tf.float32, [None, 10])  # labels (dataset built with one_hot=True)

c_ = classifier(x, 10)
d = tf.nn.sigmoid(discriminator_cond(x, c))
d_ = tf.nn.sigmoid(discriminator_cond(x, c_, reuse=True))

# Gradient penalty on the label input, evaluated at interpolates between c and c_.
# One interpolation coefficient per example (shape [batch, 1], broadcast over the 10
# label columns) instead of a single scalar shared by the whole batch, so the penalty
# is enforced at many points along the real/predicted label segments in each step.
GP_WEIGHT = 1.0
epsilon = tf.random_uniform([tf.shape(c)[0], 1], 0.0, 1.0)
c_hat = epsilon * c + (1 - epsilon) * c_
d_hat = tf.nn.sigmoid(discriminator_cond(x, c_hat, reuse=True))

ddc = tf.gradients(d_hat, c_hat)[0]
# The tiny constant keeps sqrt differentiable when a gradient row is exactly zero:
# d sqrt(u)/du is infinite at u = 0, and inf * 0 would inject NaNs into d_train.
ddc = tf.sqrt(tf.reduce_sum(tf.square(ddc), axis=1) + 1e-12)
d_grad_loss = GP_WEIGHT * tf.reduce_mean(tf.square(ddc - 1.0))

# The discriminator minimises mean(d_) - mean(d); the classifier maximises mean(d_).
d_confusion = tf.reduce_mean(d_) - tf.reduce_mean(d)
d_loss = d_confusion + d_grad_loss
g_loss = -tf.reduce_mean(d_)

In [None]:
# Optimise a free 100 x 10 matrix of label logits so that the (already built)
# discriminator scores the squashed labels as highly as possible for the fed images.
# NOTE(review): x presumably has to be fed with exactly 100 images here so the batch
# matches optimal_c — confirm against discriminator_cond.
optimal_c = tf.get_variable(shape=(100, 10), name='optimal_c')
# Op to (re-)initialise optimal_c before an optimisation run.
# NOTE(review): the original line was truncated (`tf.`), which made this cell a
# SyntaxError; an initializer op is the presumed intent — confirm.
optimal_c_init = tf.variables_initializer([optimal_c])
optimal_c_class = tf.sigmoid(optimal_c)  # element-wise squash into (0, 1)
d_oc = tf.nn.sigmoid(discriminator_cond(x, optimal_c_class, reuse=True))
d_oc_loss = -tf.reduce_mean(d_oc)
# var_list is documented as a list/tuple of variables; only optimal_c is updated.
input_train = tf.train.GradientDescentOptimizer(learning_rate=1e-2).minimize(
    d_oc_loss, var_list=[optimal_c])

In [None]:
def _in_scope(var, scope):
    """True if `scope` is one of the '/'-separated components of the variable's name."""
    return scope in var.name.split('/')


# Split trainable variables between the discriminator ('dc_net') and the classifier
# ('c_net'). A plain substring test `'c_net' in var.name` also matches every
# 'dc_net/...' variable, which would let g_train push the discriminator weights in the
# opposite direction to d_train. Matching whole scope components avoids that.
# NOTE(review): assumes the networks build their variables under scopes named exactly
# 'dc_net' and 'c_net' — the asserts below fail loudly if that is not the case.
d_vars = [var for var in tf.trainable_variables() if _in_scope(var, 'dc_net')]
g_vars = [var for var in tf.trainable_variables() if _in_scope(var, 'c_net')]
assert d_vars and g_vars, "no variables found under the 'dc_net' / 'c_net' scopes"
assert {v.name for v in d_vars}.isdisjoint(v.name for v in g_vars)

d_train = tf.train.GradientDescentOptimizer(learning_rate=5e-3).minimize(d_loss, var_list=d_vars)
g_train = tf.train.GradientDescentOptimizer(learning_rate=1e-4).minimize(g_loss, var_list=g_vars)

In [None]:
# Fraction of examples whose arg-max predicted class equals the arg-max of the fed label.
predicted_label = tf.argmax(c_, axis=1)
fed_label = tf.argmax(c, axis=1)
correct_prediction = tf.equal(predicted_label, fed_label)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Gradient of the discriminator output w.r.t. its label input: at the classifier's
# predictions (dc_) and at the fed labels (dc). Built in the same order as before.
dc_, dc = tf.gradients(d_, c_)[0], tf.gradients(d, c)[0]

In [None]:
# allow_growth: allocate GPU memory on demand rather than reserving it all up front.
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [None]:
def print_stats():
    """Print losses and accuracy on a 500-example batch from dataset.next_labeled_batch.

    Unlike the training samplers, the labels here are fed as-is (no label_noise).
    """
    batch_x, batch_c = dataset.next_labeled_batch(500)
    fetches = [g_loss, d_loss, d_confusion, accuracy, d_grad_loss]
    values = sess.run(fetches, feed_dict={x: batch_x, c: batch_c})
    template = "g_loss=%6.2f, d_loss=%6.2f, d_confusion=%6.2f, accuracy=%6.2f, grad_loss=%6.2f"
    print(template % tuple(values))


print_stats()

In [None]:
def _noisy_batch(fetch_batch, size):
    """Fetch `size` labeled examples and pass the labels through label_noise."""
    images, labels = fetch_batch(size)
    return images, label_noise(labels)


def sample_train(size=128):
    """Return (images, labels) from dataset.next_labeled_batch, labels via label_noise."""
    return _noisy_batch(dataset.next_labeled_batch, size)


def sample_test(size=128):
    """Return (images, labels) from dataset.next_labeled_test_batch, labels via label_noise."""
    return _noisy_batch(dataset.next_labeled_test_batch, size)

In [None]:
# Main training loop: 10,000 steps on label-noised training batches.
# NOTE(review): make_plots and the fixed evaluation batch (tbx, tbc) are defined in
# later cells of this notebook; run those first, or the first plot raises NameError.
LOG_EVERY = 500
for idx in range(1, 10001):
    bx, bc = sample_train(128)
    feed = {x: bx, c: bc}
    # Two separate runs instead of sess.run([d_train, g_train]): fetching both ops in one
    # call gives no ordering guarantee, so g_train could read the discriminator weights
    # either before or after d_train updates them, making training nondeterministic.
    sess.run(d_train, feed_dict=feed)
    sess.run(g_train, feed_dict=feed)
    if idx % LOG_EVERY == 0:
        print_stats()
        make_plots()

In [None]:
# 1,000 further steps: one discriminator update, then one classifier update, on the
# same batch. Stats are printed every 100 steps, including step 0.
for idx in range(1000):
    bx, bc = sample_train(128)
    feed = {x: bx, c: bc}
    for train_op in (d_train, g_train):
        sess.run(train_op, feed_dict=feed)
    if not idx % 100:
        print_stats()

In [None]:
def make_plots(bx=None, bc=None, n_rows=20):
    """Show label-gradients and label vectors for a batch as four side-by-side images.

    Panels, each showing the first `n_rows` examples as rows:
      1. gradient of the discriminator output w.r.t. the fed labels (dc),
      2. the same gradient at the classifier's predicted labels (dc_),
      3. the classifier's predicted label vectors (c_),
      4. the fed label vectors.

    Args:
        bx, bc: image / label batch to evaluate. When both are omitted, the global
            evaluation batch `tbx`, `tbc` is used (it must already exist). Pass both
            or neither.
        n_rows: number of examples shown per panel.
    """
    if bx is None and bc is None:
        bx, bc = tbx, tbc
    dc_val, dc_val_, bc_ = sess.run([dc, dc_, c_], feed_dict={x: bx, c: bc})
    panels = [
        (dc_val, 'dD/dc (fed labels)'),
        (dc_val_, 'dD/dc (predicted)'),
        (bc_, 'predicted labels'),
        (bc, 'fed labels'),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=(10, 10))
    for ax, (values, title) in zip(axes, panels):
        ax.imshow(values[:n_rows], interpolation='none', cmap='Greys')
        ax.set_title(title, fontsize=9)
    plt.show()
    # Called repeatedly from the training loop; release the figure explicitly.
    plt.close(fig)

In [None]:
# Fixed batch drawn once and reused by make_plots, so successive plots show the same examples.
# NOTE(review): drawn with sample_train (dataset.next_labeled_batch, labels passed through
# label_noise); if a held-out batch was intended, sample_test may be the right call — confirm.
tbx, tbc = sample_train(size=128)