In [1]:
import numpy as np
import matplotlib
matplotlib.use('Agg',warn=False)
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')
from IPython.display import Markdown as md


In [16]:
import os

# Locations used throughout the notebook, and the training minibatch size.
fdir = './outs_test'                        # generated figures are written here
modeldir = './pretrained_model/'            # checkpoints looked up when restoring the model
ae_dir = './wae_metric/pretrained_model/'   # checkpoints looked up for the autoencoder
batch_size = 100

# Create the output and model directories when they are not there yet.
for _required_dir in (fdir, modeldir):
    if os.path.exists(_required_dir):
        continue
    os.makedirs(_required_dir)

In [4]:
# The loader and the latent size are taken from the autoencoder scripts so
# that this notebook stays consistent with them.
from wae_metric.run_WAE import LATENT_SPACE_DIM, load_dataset

LATENT_DIM = LATENT_SPACE_DIM

# Three arrays: input parameters, output scalars and output images.
jag_inp, jag_sca, jag_img = load_dataset('./data/')

In [5]:
# Report the shapes of the three loaded arrays in a single banner.
print('---------------Dataset Information---------------\nInput parameters: {}, Output Scalars: {}, Output Images: {}'.format(
    *(arr.shape for arr in (jag_inp, jag_sca, jag_img))))

---------------Dataset Information---------------
Input parameters: (10000, 5), Output Scalars: (10000, 15), Output Images: (10000, 16384)


## Create Train/Test Splits

In [7]:
# Seed NumPy's global RNG; this is the random seed used during training.
np.random.seed(4321)

n_samples = jag_sca.shape[0]
n_train = int(n_samples*0.8)

# Draw 80% of the row indices, without replacement, for the training split.
tr_id = np.random.choice(n_samples,n_train,replace=False)
print(tr_id[:10])

# Every index that was not drawn for training belongs to the test split.
te_id = list(set(range(n_samples)) - set(tr_id))


[7763 6764 6662 5371 7257 2963 1321 6730 9597 3155]


In [8]:
# Training split: the rows picked by the training indices.
X_train, y_sca_train, y_img_train = (
    arr[tr_id,:] for arr in (jag_inp, jag_sca, jag_img))

# Shuffle the test indices in place. Re-running this cell without re-creating
# te_id shuffles the already shuffled list again.
np.random.shuffle(te_id)

# Test split: the rows picked by the shuffled test indices.
X_test, y_sca_test, y_img_test = (
    arr[te_id,:] for arr in (jag_inp, jag_sca, jag_img))

# Last 100 test images, reshaped to (100, 64, 64, 4) for plotting.
y_img_test_mb = y_img_test[-100:,:].reshape(100,64,64,4)

## Save Ground Truth Images in "fdir"

In [10]:
from utils import plot

# One figure per slice along the last axis of the ground-truth minibatch.
for k in range(4):
    gt_channel = y_img_test_mb[:,:,:,k]
    # Flatten each 64x64 image so the per-image extrema can be taken row-wise.
    gt_flat = gt_channel.reshape(-1,4096)
    fig = plot(gt_channel,
               immax=np.max(gt_flat,axis=1),
               immin=np.min(gt_flat,axis=1))
    out_png = '{}/gt_img_{}_{}.png'.format(fdir,str(k).zfill(3),str(k))
    plt.savefig(out_png, bbox_inches='tight')
    plt.close()


In [12]:
# Second-axis sizes of the input, scalar and image training arrays.
dim_x, dim_y_sca, dim_y_img = (
    arr.shape[1] for arr in (X_train, y_sca_train, y_img_train))

# Size of the latent space.
dim_y_img_latent = LATENT_DIM


## Build the Computational Graph


In [17]:
from modelsv2 import cycModel_MM
import wae_metric.model_AVB as wae
import tensorflow as tf

# Clear the default graph before (re)building the model in it.
tf.reset_default_graph()

# Graph inputs: output scalars, flattened output images, input parameters,
# and a boolean switch between training and inference behaviour.
y_sca = tf.placeholder(dtype=tf.float32, shape=(None, dim_y_sca))
y_img = tf.placeholder(dtype=tf.float32, shape=(None, dim_y_img))
x = tf.placeholder(dtype=tf.float32, shape=(None, dim_x))
train_mode = tf.placeholder(dtype=tf.bool, name='train_mode')

# Images and scalars are concatenated column-wise into one multimodal output.
y_mm = tf.concat([y_img, y_sca], axis=1)

### 1. Map outputs (images, scalars) --> latent space with pre-trained autoencoder
y_latent_img = wae.gen_encoder_FCN(y_mm, dim_y_img_latent, train_mode)

### 2. Next, build the CycleGAN that learns to map input params <--> latent vector
cycGAN_params = dict(
    input_params=x,
    outputs=y_latent_img,
    param_dim=dim_x,
    output_dim=dim_y_img_latent,
    L_adv=1e-2,  # controls "physical" consistency
    L_cyc=1e-1,  # controls cyclical consistency
    L_rec=1.,    # controls fidelity of surrogate
)

JagNet_MM = cycModel_MM(**cycGAN_params)
JagNet_MM.run(train_mode)

### 3. Decode the predictions from the CycleGAN into output space of images and scalars
y_img_out = wae.var_decoder_FCN(JagNet_MM.output_fake,
                                dim_y_img+dim_y_sca,
                                train_mode)


In [18]:
# Partition the graph variables by name: those containing 'wae' are saved and
# restored separately (from ae_dir) from all remaining variables (modeldir).
t_vars = tf.global_variables()
m_vars = [var for var in t_vars if 'wae' in var.name]
# Name-based complement of m_vars; keeps the creation order instead of going
# through an unordered set difference.
other_vars = [var for var in t_vars if 'wae' not in var.name]
metric_saver = tf.train.Saver(m_vars)
saver = tf.train.Saver(other_vars)

sess = tf.Session()
# Initialise everything first; restored checkpoints then overwrite the
# initial values of the variables they cover.
sess.run(tf.global_variables_initializer())

ckpt = tf.train.get_checkpoint_state(modeldir)
ckpt_metric = tf.train.get_checkpoint_state(ae_dir)

if ckpt_metric and ckpt_metric.model_checkpoint_path:
    metric_saver.restore(sess, ckpt_metric.model_checkpoint_path)
    print("************ Image Metric Restored! **************")
else:
    # Do not fail silently: without this checkpoint the 'wae' variables keep
    # their freshly initialised values.
    print("************ WARNING: no image metric checkpoint found in '{}'; "
          "'wae' variables keep their initial values! **************".format(ae_dir))

if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    print("************ Model restored! **************")
else:
    # An empty model directory is legitimate, so only report it.
    print("************ No model checkpoint found in '{}'; "
          "starting from initial values. **************".format(modeldir))


INFO:tensorflow:Restoring parameters from ./wae_metric/pretrained_model/model_99999.ckpt
************ Image Metric Restored! **************
INFO:tensorflow:Restoring parameters from ./pretrained_model/model_99500.ckpt
************ Model restored! **************


## Train the network

In [None]:
from utils import test_imgs_plot

for it in range(50000):

        # Draw a minibatch of distinct training rows.
        randid = np.random.choice(X_train.shape[0],batch_size,replace=False)
        x_mb, y_img_mb, y_sca_mb = (X_train[randid,:],
                                    y_img_train[randid,:],
                                    y_sca_train[randid,:])

        fd = {x: x_mb, y_sca: y_sca_mb, y_img: y_img_mb, train_mode: True}

        # Run D_solver and read the three losses in the same call, then run
        # G0_solver on the same minibatch.
        d_fetches = [JagNet_MM.D_solver, JagNet_MM.loss_disc,
                     JagNet_MM.loss_gen0, JagNet_MM.loss_gen1]
        _, dloss, gloss0, gloss1 = sess.run(d_fetches, feed_dict=fd)
        _ = sess.run([JagNet_MM.G0_solver], feed_dict=fd)

        # Progress line every 100 iterations.
        if not it%100:
            print('Iter: {}; forward loss: {:.4}; inverse loss: {:.4}'
                  .format(it, gloss0, gloss1))

        # Every 500 iterations, run the last nTest test inputs through the
        # network in inference mode and hand the results to the plotting helper.
        if not it%500:
            nTest = 16
            x_test_mb = X_test[-nTest:,:]
            samples, samples_x = sess.run([y_img_out, JagNet_MM.input_cyc],
                                          feed_dict={x: x_test_mb, train_mode: False})
            # NOTE(review): 'y_sca' and 'y_img' carry the full test arrays while
            # 'x' holds only the last nTest rows; presumably test_imgs_plot
            # slices them itself - confirm against utils.
            data_dict = {'samples': samples,
                         'samples_x': samples_x,
                         'y_sca': y_sca_test,
                         'y_img': y_img_test,
                         'x': x_test_mb}

            test_imgs_plot(fdir,it,data_dict)
            