In [None]:
import os
import argparse
import tensorflow as tf
from clustergan import ClusterGAN

def str2bool(value):
    """Parse a command-line boolean flag.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--is_training False`` would silently enable
    training. This converter accepts the usual spellings instead.

    Args:
        value: the raw command-line string (a bool is passed through).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got %r' % value)


# Hyper-parameters and paths for loading a trained ClusterGAN.
parser = argparse.ArgumentParser(description='ClusterGAN')
parser.add_argument('--model_dir', type=str,
                    default='./exp',
                    help='Directory in which the model is stored')
parser.add_argument('--is_training', type=str2bool, default=False, help='whether it is training or inferencing')
parser.add_argument('--n_cat', type=int, default=1, help='number of categorical variables')
parser.add_argument('--num_classes', type=int, default=10, help='dimension of categorical variables')
parser.add_argument('--dim_gen', type=int, default=30, help='continuous dim of latent variable')
parser.add_argument('--z_dim', type=int, default=40, help='random noise dim of latent variable')
parser.add_argument('--sampler', type=str, default='one_hot')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--epoch', type=int, default=5000, help='epochs')
parser.add_argument('--saving_cycle', type=int, default=1, help='how often the model will be saved')
parser.add_argument('--d_lr', type=float, default=1e-4, help='learning rate for discriminator')
parser.add_argument('--g_lr', type=float, default=1e-4, help='learning rate for generator')
parser.add_argument('--gpu_num', type=str, default="1", help='gpu to be used')

# parse_known_args returns unrecognised arguments in `unparsed` instead of
# erroring. This keeps the cell working under a Jupyter kernel, whose argv
# carries its own flags.
args, unparsed = parser.parse_known_args()

# Restrict TensorFlow to the requested GPU. Enumerate devices in PCI bus order
# so the index matches nvidia-smi. NOTE(review): this must take effect before
# TF initialises CUDA (presumably at the first tf.Session below).
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu_num

# Build the ClusterGAN graph from the parsed hyper-parameters.
model = ClusterGAN(args)

# TF1 session options:
# - allow_soft_placement lets ops without a kernel on the requested device
#   fall back to another device (e.g. the CPU).
# - log_device_placement logs the device each op is assigned to.
sess_config = tf.ConfigProto(allow_soft_placement=True,
                             log_device_placement=True)
sess = tf.Session(config=sess_config)

# Give every graph variable an initial value. The checkpoint restore that
# follows presumably overwrites the trained ones.
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [None]:
# A checkpoint directory is required to restore the trained model.
# Defensive only: with the './exp' default this fires only if `args` is
# edited by hand.
checkpoint_dir = args.model_dir
if checkpoint_dir is None:
    raise ValueError('Need to provide model directory')

# Load the saved model into the open session (see ClusterGAN.load).
model.load(sess, checkpoint_dir)

In [None]:
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
from utils import sample_Z

In [None]:
# Mode accuracy: for each latent mode, generate one batch whose categorical
# code is fixed to that mode and save every image under <test_dir>/<mode>/
# for visual inspection.
# NOTE(review): assumes sample_Z(batch, z_dim, sampler, num_classes, n_cat, modes)
# fixes the categorical part to `modes` -- TODO confirm against utils.sample_Z.
fig, ax = plt.subplots()
for mode in range(args.num_classes):
    modes = [mode] * args.batch_size
    testz = sample_Z(args.batch_size, args.z_dim, args.sampler, args.num_classes, args.n_cat, modes)
    fake = sess.run(model.x_, feed_dict={model.z: testz})
    fake = np.reshape(fake, [-1, 28, 28])

    dir_path = os.path.join(model.test_dir, str(mode))
    os.makedirs(dir_path, exist_ok=True)
    for i, img in enumerate(fake):
        # Reuse one figure but clear it per image. Repeated plt.imshow on the
        # shared current axes would stack an extra AxesImage each call, so
        # every save would redraw all earlier images and memory would grow.
        ax.clear()
        ax.imshow(img)
        ax.axis('off')
        fig.savefig(os.path.join(dir_path, str(i) + '.png'))
plt.close(fig)  # release the figure; the images are on disk

In [None]:
# Cluster accuracy: encode one batch of real MNIST test images and check
# whether each image's inferred cluster maps to its true digit.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
bx, bl = mnist.test.next_batch(args.batch_size)
bx = np.reshape(bx, [-1, 28, 28, 1])  # (N, 28, 28, 1) input for model.x

zhats_gen, zhats_label = sess.run(
    [model.z_infer_gen, model.z_infer_label],
    feed_dict={model.x : bx},
)

# Hard-coded cluster-index -> digit mapping.
# NOTE(review): cluster ids are arbitrary per training run, so this presumably
# only matches the checkpoint it was derived from -- TODO recompute for new models.
mode2label = [6,5,0,3,2,9,7,4,8,1]

# Vectorised: map each row's most likely cluster to a digit, compare it with
# the one-hot label, and report the hit rate as a percentage.
predicted = np.asarray(mode2label)[np.argmax(zhats_label, axis=1)]
actual = np.argmax(bl, axis=1)
acc = np.count_nonzero(predicted == actual) / args.batch_size * 100
print(acc)

In [None]:
# Reconstruction check: feed the inferred latent codes (z_infer_gen followed
# by z_infer_label) back through the generator and compare the first
# reconstruction with its original image.
# NOTE(review): assumes model.z expects the [gen, label] parts in this order,
# as sample_Z presumably lays them out -- TODO confirm.
testz = np.concatenate((zhats_gen, zhats_label), axis=1)
recon = sess.run(model.x_, feed_dict={model.z: testz})
recon = np.reshape(recon, [-1, 28, 28]) # reconstructed images
bx = np.reshape(bx, [-1, 28, 28]) # original images

# Use two separate axes. Calling plt.imshow twice on the same axes drew the
# original directly on top of the reconstruction, hiding it completely.
fig_recon, (ax_orig, ax_recon) = plt.subplots(1, 2)
ax_orig.imshow(bx[0])
ax_orig.set_title('original')
ax_orig.axis('off')
ax_recon.imshow(recon[0])
ax_recon.set_title('reconstruction')
ax_recon.axis('off')
plt.show()