In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tqdm import tqdm

from trainers import TrainerForVAEGAN
from utils import unpickle
from vae_gans import CIFAR_VAE_GAN

# Restrict TensorFlow to a single GPU. This must take effect before TensorFlow
# first initializes CUDA (i.e. before the session is created further below).
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

datadir = './CIFAR10_data/'
# os.listdir returns entries in arbitrary order, but the loading cell indexes
# this list positionally (0-4 = training batches, 5 = test batch), so sort it
# to make the order deterministic.
# NOTE(review): assumes the directory holds only the six batch files; the
# stock CIFAR-10 archive also ships `batches.meta` / `readme.html` — confirm.
batches = [os.path.join(datadir, name) for name in sorted(os.listdir(datadir))]


In [None]:
# Load the five training batches (batches[0:5]) and the test batch (batches[5]).
# NOTE(review): relies on `batches` being ordered so these indices line up.
image_chunks = []
label_chunks = []
for batch_path in tqdm(batches[:5]):
    batch = unpickle(batch_path)

    data = batch[b'data'].astype(np.float32)
    # Each flat row is reshaped to (3, 32, 32) and transposed to (32, 32, 3),
    # i.e. channel-first storage converted to NHWC.
    image_chunks.append(np.transpose(np.reshape(data, [-1, 3, 32, 32]), [0, 2, 3, 1]))
    label_chunks.append(np.asarray(batch[b'labels']))

# Concatenate once rather than growing the array inside the loop, which
# re-copied everything accumulated so far on every iteration.
cifar = np.concatenate(image_chunks, axis=0)
labels = np.concatenate(label_chunks, axis=0)

# Map pixel values from [0, 255] to [-1, 1].
scaled_cifar = cifar / 127.5 - 1.0

test_batch = unpickle(batches[5])
# Cast to float32 like the training data; dividing the raw (presumably uint8)
# array directly would otherwise produce a float64 test set.
test_data = test_batch[b'data'].astype(np.float32)
cifar_test = np.transpose(np.reshape(test_data, [-1, 3, 32, 32]), [0, 2, 3, 1])
scaled_cifar_test = cifar_test / 127.5 - 1.0
labels_test = np.array(test_batch[b'labels'])

data_train = (scaled_cifar, labels)
data_test = (scaled_cifar_test, labels_test)


In [None]:
# Preview the first 16 training images in a borderless 4x4 grid.
fig, axes = plt.subplots(4, 4, figsize=(4, 4))

for idx, ax in enumerate(axes.flat):
    # (x + 1) * 0.5 inverts the `/ 127.5 - 1.0` scaling applied at load time.
    pixels = (data_train[0][idx] + 1) * 0.5
    ax.imshow(pixels.reshape(32, 32, 3))
    ax.set_xticks([])
    ax.set_yticks([])

fig.suptitle('Training Data', fontsize=20, y=1.03)
fig.tight_layout()
fig.subplots_adjust(wspace=0.0, hspace=0.0)

plt.show()
plt.close(fig)


In [None]:
# Build the CIFAR VAE-GAN in a fresh graph and train it on the training split.
SEED = 0
LOG_DIR = 'tf_logs/exp3/vae-gan/data_test/'
N_EPOCHS = 150
N_DIS = 5  # forwarded as `n_dis`; presumably discriminator updates per generator update — TODO confirm

tf.reset_default_graph()
# Seed after resetting: the graph-level seed is attached to the new default graph.
# NOTE: GPU kernels may still be nondeterministic, so runs need not match bit-for-bit.
tf.set_random_seed(SEED)
np.random.seed(SEED)

vae_gan = CIFAR_VAE_GAN(LOG_DIR, lmda=1e-2, zdim=128, learning_rate=2e-4, beta1=0.0,
                        beta2=0.9)

# InteractiveSession installs itself as the default session for later cells.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

trainer = TrainerForVAEGAN(sess, vae_gan, data_train, n_dis=N_DIS)
# NOTE(review): data_test is prepared earlier but never passed in — confirm intended.
trainer.train(N_EPOCHS, p_epochs=1)
