# Invertible Neural Network Tutorial

### Model to use: Real-NVP (https://arxiv.org/abs/1605.08803)



### Environment setup and load dataset

In [1]:
import numpy as np
import tensorflow as tf
import os
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline

# Do not use GPU
os.environ['CUDA_VISIBLE_DEVICES']=''

'''

(X_train, Y_train), (X_test, Y_test) = tf.keras.datasets.mnist.load_data(path='mnist.npz')

print(X_train.shape)
print(Y_train.shape)

plt.imshow(X_train[0])
plt.show()
'''

dataS = pd.read_csv('Dataset/PCC/MLDataset_quat.csv')
dataS = dataS.drop('Unnamed: 0',axis=1)
X = dataS.iloc[:,:7].values
Y = dataS.iloc[:,7:].values

### Invertible Neural Networks

In [2]:
# Define network for calculating s and t
def dense_resnet(inputs, mid_channels, output_channels, num_blocks):
    assert len(inputs.shape) == 2
    def _blocks(_x, name):
        shortcut = _x
        _x = tf.layers.dense(_x, mid_channels, activation='relu', name=name+'1')
        _x = tf.layers.dense(_x, mid_channels, activation='relu', name=name+'2')
        return _x + shortcut
    
    inputs = tf.layers.dense(inputs, mid_channels, activation='relu', name='initial')
    
    for i in range(num_blocks):
        inputs = _blocks(inputs, '{}'.format(i))
    inputs = tf.layers.dense(inputs, mid_channels, activation=None, name='final')
    return inputs
        
#Define mask
def get_mask(inputs, reverse_mask, data_format='NHWC', dtype=tf.float32):
    shape = inputs.get_shape().as_list()
    if len(shape) == 2:
        N = shape[-1]
        range_n = tf.range(N)
        odd_ind = tf.mod(range_n, 2)
        
        odd_ind = tf.reshape(odd_ind, [-1, N])
        checker = odd_ind
        
    
    elif len(shape) == 4:
        H = shape[2] if data_format == 'NCHW' else shape[1]
        W = shape[3] if data_format == 'NCHW' else shape[2]
               
        range_h = tf.range(H)
        range_w = tf.range(W)
        
        odd_ind_h = tf.cast(tf.mod(range_h, 2), dtype=tf.bool)
        odd_ind_w = tf.cast(tf.mod(range_w, 2), dtype=tf.bool)
        
        odd_h = tf.tile(tf.expand_dims(odd_ind_h, -1), [1, W])
        odd_w = tf.tile(tf.expand_dims(odd_ind_w,  0), [H, 1])
                
        checker = tf.logical_xor(odd_h, odd_w)
        
        reshape = [-1, 1, H, W] if data_format == 'NCHW' else [-1, H, W, 1]
        checker = tf.reshape(checker, reshape)
        
    
    else:
        raise ValueError('Invalid tensor shape. Dimension of the tensor shape must be '
                         '2 (NxD) or 4 (NxCxHxW or NxHxWxC), got {}.'.format(inputs.get_shape().as_list()))
        
        
    if checker.dtype != dtype:
        checker = tf.cast(checker, dtype)
        
    if reverse_mask:
        checker = 1. - checker
        
    return checker

# Define coupling layer
def coupling_layer(inputs, mid_channels, num_blocks, reverse_mask, name='coupling_layer', backward=False, reuse=None):
    mask = get_mask(inputs, reverse_mask)
    with tf.variable_scope(name) as scope:
        if reuse:
            scope.reuse_variables()
            
        if backward:
            v1 = inputs * mask
            v2 = inputs * (1-mask)
            with tf.variable_scope('st1'):
                st1 = dense_resnet(inputs=v1, mid_channels=mid_channels, output_channels=inputs.get_shape().as_list()[1]*2, num_blocks=3)
                s1 = st1[:, 0:tf.shape(inputs)[1]]
                rescale1 = tf.get_variable('rescale_s', shape=[inputs.get_shape().as_list()[1]], dtype=tf.float32, initializer=tf.constant_initializer(1.))
                s1 = rescale1 * tf.nn.tanh(s1)
                t1 = st1[:, tf.shape(inputs)[1]:tf.shape(inputs)[1]*2]
                
            u2 = (1-mask)*(v2 - t1)*tf.exp(-s1)
        
            with tf.variable_scope('st2'):
                st2 = dense_resnet(inputs=u2, mid_channels=mid_channels, output_channels=inputs.get_shape().as_list()[1]*2, num_blocks=3)
                s2 = st2[:, 0:tf.shape(inputs)[1]]
                rescale2 = tf.get_variable('rescale_s', shape=[inputs.get_shape().as_list()[1]], dtype=tf.float32, initializer=tf.constant_initializer(1.))
                s2 = rescale2 * tf.nn.tanh(s2)
                t2 = st2[:, tf.shape(inputs)[1]:tf.shape(inputs)[1]*2]
                
            u1 = mask * (v1 - t2)*tf.exp(-s2)
            inputs = u1 + u2
        
        else:
            u1 = inputs * mask
            u2 = inputs * (1-mask)
        
            with tf.variable_scope('st2'):
                st2 = dense_resnet(inputs=u2, mid_channels=mid_channels, output_channels=inputs.get_shape().as_list()[1]*2, num_blocks=3)
                s2 = st2[:, 0:tf.shape(inputs)[1]]
                rescale2 = tf.get_variable('rescale_s', shape=[inputs.get_shape().as_list()[1]], dtype=tf.float32, initializer=tf.constant_initializer(1.))
                s2 = rescale2 * tf.nn.tanh(s2)
                t2 = st2[:, tf.shape(inputs)[1]:tf.shape(inputs)[1]*2]
        
            v1 = mask * (u1 * tf.exp(s2) + t2)
            
            with tf.variable_scope('st1'):
                st1 = dense_resnet(inputs=v1, mid_channels=mid_channels, output_channels=inputs.get_shape().as_list()[1]*2, num_blocks=3)
                s1 = st1[:, 0:tf.shape(inputs)[1]]
                rescale1 = tf.get_variable('rescale_s', shape=[inputs.get_shape().as_list()[1]], dtype=tf.float32, initializer=tf.constant_initializer(1.))
                s1 = rescale1 * tf.nn.tanh(s1)
                t1 = st1[:, tf.shape(inputs)[1]:tf.shape(inputs)[1]*2]
        
        
            v2 = (1-mask) * (u2 * tf.exp(s1) + t1)
            inputs = v1 + v2
        
        return inputs
    
# Code from https://github.com/chrischute/real-nvp
def preprocess(x):
    data_constraint = 0.9
    y = (x*255. + tf.random.uniform(tf.shape(x), 0, 1))/256.
    y = (2 * y - 1) * data_constraint
    y = (y + 1) / 2
    y = tf.log(y) - tf.log(1-y)
    
    ldj = tf.nn.softplus(y) + tf.nn.softplus(-y) - tf.nn.softplus(tf.log(1-data_constraint) - tf.log(data_constraint))
    sldj = tf.reduce_sum(tf.reshape(ldj, [tf.shape(ldj)[0], -1]), axis=-1)
    return y, sldj

    
def real_nvp(inputs, mid_channels, backward=False, reuse=False):
#    
    x = inputs
    if backward:
       
        x = coupling_layer(x, mid_channels, 4, reverse_mask=True, name='c4', backward=backward, reuse=reuse)
        x = coupling_layer(x, mid_channels, 4, reverse_mask=False, name='c3', backward=backward, reuse=reuse)
        x = coupling_layer(x, mid_channels, 4, reverse_mask=True, name='c2', backward=backward, reuse=reuse)
        x = coupling_layer(x, mid_channels, 4, reverse_mask=False, name='c1', backward=backward, reuse=reuse)
    else:
        
#        x, sldj = preprocess(inputs)
        x = coupling_layer(x, mid_channels, 4, reverse_mask=False, name='c1', backward=backward, reuse=reuse)
        x = coupling_layer(x, mid_channels, 4, reverse_mask=True, name='c2', backward=backward, reuse=reuse)
        x = coupling_layer(x, mid_channels, 4, reverse_mask=False, name='c3', backward=backward, reuse=reuse)
        x = coupling_layer(x, mid_channels, 4, reverse_mask=True, name='c4', backward=backward, reuse=reuse)
#    x = tf.nn.sigmoid(x)
    return x


In [3]:
forward_inputs = tf.placeholder(tf.float32, [None, 4])
output = real_nvp(forward_inputs, 8, backward=False, reuse=False)

backward_inputs = tf.placeholder(tf.float32, [None, 4])
#x_restored = real_nvp(backward_inputs, 8, backward=True, reuse=True)
x_restored = real_nvp(output, 8, backward=True, reuse=True)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
X_toy, _ = sess.run(preprocess(X_toy))

output_ = sess.run(output, feed_dict={forward_inputs:X_toy})
#print(sess.run(preprocess(forward_inputs), feed_dict={forward_inputs:X_toy}))
#restored_ = sess.run(x_restored, feed_dict={backward_inputs:output_[0]})
restored_ = sess.run(x_restored, feed_dict={forward_inputs:X_toy})
print('Original: \n',X_toy)
print('Forward: \n', output_)
print('Restored: \n', restored_)


#restored_ = sess.run(restored)

Instructions for updating:
Use keras.layers.dense instead.
Instructions for updating:
Colocations handled automatically by placer.
Original: 
 [[-0.39975804 -0.03595382  0.10423118  2.5258074 ]
 [-0.64843106 -1.3048775   0.8708211   0.03910542]]
Forward: 
 [[  2.719318     0.14667463  -0.80206853   4.1788845 ]
 [  6.4994426  -10.969748    -3.5650327   -3.517817  ]]
Restored: 
 [[-0.39975777 -0.03595387  0.10423115  2.5258079 ]
 [-0.64843154 -1.3048768   0.8708212   0.03910562]]
