A working implementation of the AAE in Tensorflow as in the old GitHub repository.

In [1]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm

  from ._conv import register_converters as _register_converters


In [2]:
# Load the train and test files

train = 'train_aae.txt'
test = 'test_aae.txt'

```batch_gen``` generates the batches and randomly shuffles them and ```buffered_gen``` generates the buffer and uses it to store the batches.

In [3]:
def batch_gen(data, batch_n):
    """
    Given the data, returns the batches using random shuffling.
    
    Parameters
    ----------
    data: np.array
        Consists of all the instances set into a numpy array
    batch_n: int
        Size of each batch
        
    Returns
    -------
    data: Generator object
        A generator object with all the batches in it.
    
    """
    # Create a list of indices for the data,
    # which is of the same size as the data.
    # For eg,
    #     If data is of size 1000,
    #     inds = [0, 1, ..., 999]
    inds = list(range(data.shape[0]))
    
    # Randomly shuffle the indices.
    # For eg,
    #     inds = [650, 720, ..., 2]
    np.random.shuffle(inds)
    
    # Generate size of data / batch size
    # number of batches.
    # For eg,
    #     If data is of size 1000 and 
    #     batch_n = 50, then i will be
    #     in range 0, 19
    for i in range(int(data.shape[0] / batch_n)):
        
        # Split inds according to the batches
        # created by i
        ii = inds[i*batch_n:(i+1)*batch_n]
        
        # Return a generator with each index
        # in ii matching that tuple from data.
        # For eg,
        #     If data = [[0, 1], [1, 2], ..., [999, 1000]]
        #     and ii = [50, 55, 1] (and batch_n = 3)
        #     then return [[50, 51], [55, 56], [1, 2]]
        yield data[ii, :]

In [4]:
def buffered_gen(f, batch_n=1024, buffer_size=2000):
    """
    Creates the batches by reading a file 'f', where the
    data is stored. 
    
    The data is stored in a buffer.
    
    Parameters
    ----------
    
    f: str
        String containing address of file
        
    batch_n: int
        Size of each batch
        
    buffer_size: int
        Size of the buffer. Denotes total number
        of batches which can be possibly stored in
        the buffer.
        
    Returns
    -------
    A generator object with all the batches in it. 
    
    """
    
    # Open file f
    inp = open(f)
    
    # Create new data list
    data = []
    
    # i = index of line, line is the line read
    for i, line in enumerate(inp):
        
        # For each line,
        # the line is first stripped, and the split according to tabs
        # the first element is read into an array and each input of it
        # is converted to a float. This is then appended to data.
        data.append(np.array(list(map(float, line.strip().split('\t')[1]))))
        
        # If size of buffer is finished, that is the buffer can store
        # data only uptill the number of instances = buffer_size * batch_n
        # and if the next instance fills up the buffer, then...
        if (i+1) % (buffer_size * batch_n) == 0:
            
            # Generate batches for whatever has been stored in data so far
            bgen = batch_gen(np.vstack(data), batch_n)
            
            # Yield the batches
            for batch in bgen:
                yield batch
            
            # Empty data
            data = []
            
    else:
        # Generate batches while leaving out the last element of data
        bgen = batch_gen(np.vstack(data[:-1]), batch_n)

        # Yield the batches
        for batch in bgen:
            yield batch

To load the data from the test samples.

In [5]:
def load_test():
    with open(test) as inp:
        data = [np.array(list(map(float, line.strip().split('\t')[1]))) for line in inp]
    return np.vstack(data)

Utility functions for AAE

In [6]:
def he_initializer(size):
    return tf.random_normal_initializer(mean=0.0, stddev=np.sqrt(1. / size), seed=None, dtype=tf.float32)

In [7]:
def linear_layer(tensor, input_size, out_size, init_fn=he_initializer,):
    W = tf.get_variable('W', shape=[input_size, out_size], initializer=init_fn(input_size))
    b = tf.get_variable('b', shape=[out_size], initializer=tf.constant_initializer(0.1))
    return tf.add(tf.matmul(tensor, W), b)

In [8]:
def bn_layer(tensor, size, epsilon=0.0001):
    batch_mean, batch_var = tf.nn.moments(tensor, [0])
    scale = tf.get_variable('scale', shape=[size], initializer=tf.constant_initializer(1.))
    beta = tf.get_variable('beta', shape=[size], initializer=tf.constant_initializer(0.))
    return tf.nn.batch_normalization(tensor, batch_mean, batch_var, beta, scale, epsilon)

In [9]:
def sample_prior(loc=0., scale=1., size=(64, 10)):
    return np.random.normal(loc=loc, scale=scale, size=size)

Actual implementation of the AAE

In [10]:
class AAE(object):
    def __init__(self,
                 gpu_config=None,
                 batch_size=1024, 
                 input_space=167,
                 latent_space=20,
                 middle_layers=None,
                 activation_fn=tf.nn.tanh,
                 learning_rate=0.001,
                 initializer=he_initializer):

        self.batch_size = batch_size
        self.input_space = input_space
        self.latent_space = latent_space
        if middle_layers is None:
            self.middle_layers = [256, 256]
        else:
            self.middle_layers = middle_layers
        self.activation_fn = activation_fn
        self.learning_rate = learning_rate

        self.initializer = initializer

        tf.reset_default_graph()
        
        self.input_x = tf.placeholder(tf.float32, [None, input_space])
        self.z_tensor = tf.placeholder(tf.float32, [None, latent_space])

        # Encoder net: 152->256->256->10
        with tf.variable_scope("encoder"):
            self.encoder_layers = self.encoder()
            self.encoded = self.encoder_layers[-1]
        
        # Decoder net: 10->256->256->152
        with tf.variable_scope("decoder"):
            self.decoder_layers = self.decoder(self.encoded)
            self.decoded = self.decoder_layers[-1]
            tf.get_variable_scope().reuse_variables()
            self.generator_layers = self.decoder(self.z_tensor)
            self.generated = tf.nn.sigmoid(self.generator_layers[-1])

        # Discriminator net: 10->64->64->8->1
        sizes = [64, 64, 8, 1]
        with tf.variable_scope("discriminator"):
            self.disc_layers_neg = self.discriminator(self.encoded, sizes)
            self.disc_neg = self.disc_layers_neg[-1]
            tf.get_variable_scope().reuse_variables()
            self.disc_layers_pos = self.discriminator(self.z_tensor, sizes)
            self.disc_pos = self.disc_layers_pos[-1]

        self.pos_loss = tf.nn.relu(self.disc_pos) - self.disc_pos + tf.log(1.0 + tf.exp(-tf.abs(self.disc_pos)))
        self.neg_loss = tf.nn.relu(self.disc_neg) + tf.log(1.0 + tf.exp(-tf.abs(self.disc_neg)))
        self.disc_loss = tf.reduce_mean(tf.add(self.pos_loss, self.neg_loss))
            
        self.enc_loss = tf.reduce_mean(tf.nn.relu(self.disc_neg) - self.disc_neg + tf.log(1.0 + tf.exp(-tf.abs(self.disc_neg))))
        batch_logloss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.decoded, labels=self.input_x), 1)
        self.dec_loss = tf.reduce_mean(batch_logloss)
        
        disc_ws = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator')
        enc_ws = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='encoder')
        ae_ws = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='encoder') + \
                tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='decoder')
            
        self.train_discriminator = tf.train.AdamOptimizer(self.learning_rate).minimize(self.disc_loss, var_list=disc_ws)
        self.train_encoder = tf.train.AdamOptimizer(self.learning_rate).minimize(self.enc_loss, var_list=enc_ws)
        self.train_autoencoder = tf.train.AdamOptimizer(self.learning_rate).minimize(self.dec_loss, var_list=ae_ws)

        if gpu_config is None:
            gpu_config = tf.ConfigProto()
            gpu_config.gpu_options.per_process_gpu_memory_fraction = 0.4
        
        self.sess = tf.Session(config=gpu_config)
        self.init_net()
        
    def encoder(self):
        sizes = self.middle_layers + [self.latent_space]
        with tf.variable_scope("layer-0"):
            encoder_layers = [linear_layer(self.input_x, self.input_space, sizes[0])]
        for i in range(len(sizes) - 1):
            with tf.variable_scope("layer-%i" % (i+1)):
                activated = self.activation_fn(encoder_layers[-1])
                normed = bn_layer(activated, sizes[i])
                next_layer = linear_layer(normed, sizes[i], sizes[i+1])
            encoder_layers.append(next_layer)
            
        return encoder_layers

    def decoder(self, tensor):
        sizes = self.middle_layers[::-1] + [self.input_space]
        with tf.variable_scope("layer-0"):
            decoder_layers = [linear_layer(tensor, self.latent_space, sizes[0])]
        for i in range(len(sizes) - 1):
            with tf.variable_scope("layer-%i" % (i+1)):
                activated = self.activation_fn(decoder_layers[-1])
                normed = bn_layer(activated, sizes[i])
                next_layer = linear_layer(normed, sizes[i], sizes[i+1])
            decoder_layers.append(next_layer)
        
        return decoder_layers
    
    def discriminator(self, tensor, sizes):
        with tf.variable_scope("layer-0"):
            disc_layers = [linear_layer(tensor, self.latent_space, sizes[0])]
        for i in range(len(sizes) - 1):
            with tf.variable_scope("layer-%i" % (i+1)):
                activated = tf.nn.tanh(disc_layers[-1])
                normed = bn_layer(activated, sizes[i])
                next_layer = linear_layer(normed, sizes[i], sizes[i+1])
            disc_layers.append(next_layer)

        return disc_layers
    
    def init_net(self):
        init = tf.global_variables_initializer()
        self.sess.run(init)        
    
    def train(self, log):
        sess = self.sess
        saver = tf.train.Saver()
        hist = []
        test_data = load_test()
        
        for e in tqdm(range(5)):
            print (log, "epoch #%d" % (e+1))
            log.flush()
            train_gen = buffered_gen(train, batch_n=self.batch_size)
            for i, batch_x in enumerate(train_gen):
                if i%3 == 0:
                    batch_z = sample_prior(scale=1.0, size=(len(batch_x), self.latent_space))
                    sess.run(self.train_discriminator, feed_dict={self.input_x: batch_x, self.z_tensor: batch_z})
                elif i%3 == 1:
                    sess.run(self.train_encoder, feed_dict={self.input_x: batch_x})
                else:
                    sess.run(self.train_autoencoder, feed_dict={self.input_x: batch_x})
                if i%10000 == 0:
                    batch_z = sample_prior(scale=1.0, size=(len(test_data), self.latent_space))
                    losses = sess.run([self.disc_loss, self.enc_loss, self.dec_loss],
                                      feed_dict={self.input_x: test_data, self.z_tensor: batch_z})
                    discriminator_loss, encoder_loss, decoder_loss = losses
                    print (log, "disc: %f, encoder : %f, decoder : %f" % (discriminator_loss/2., encoder_loss, decoder_loss))
                    log.flush()
            else:
                saver.save(sess, './fpt.aae.%de.model.ckpt' % e)
                batch_z = sample_prior(scale=1.0, size=(len(test_data), self.latent_space))
                losses = sess.run([self.disc_loss, self.enc_loss, self.dec_loss],
                                  feed_dict={self.input_x: test_data, self.z_tensor: batch_z})


                discriminator_loss, encoder_loss, decoder_loss = losses
                print (log, "disc: %f, encoder : %f, decoder : %f" % (discriminator_loss/2., encoder_loss, decoder_loss))
                log.flush()
                hist.append(decoder_loss)
        return hist
    
    def load(self, model):
        saver = tf.train.Saver()
        saver.restore(self.sess, model)

In [11]:
aae = AAE(batch_size=1024)

In [12]:
with open('./fpt.aae.log', 'w') as log:
     aae_0 = aae.train(log)

  0%|          | 0/5 [00:00<?, ?it/s]

<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> epoch #1
<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> disc: 0.790111, encoder : 0.739086, decoder : 136.339584


 20%|██        | 1/5 [00:19<01:18, 19.67s/it]

<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> disc: 0.694194, encoder : 0.693319, decoder : 34.655224
<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> epoch #2
<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> disc: 0.694168, encoder : 0.693329, decoder : 34.655224


 40%|████      | 2/5 [00:37<00:56, 18.68s/it]

<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> disc: 0.693558, encoder : 0.693285, decoder : 27.962450
<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> epoch #3
<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> disc: 0.693556, encoder : 0.693285, decoder : 27.962450


 60%|██████    | 3/5 [00:54<00:36, 18.33s/it]

<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> disc: 0.693373, encoder : 0.693217, decoder : 27.356462
<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> epoch #4
<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> disc: 0.693373, encoder : 0.693216, decoder : 27.356462


 80%|████████  | 4/5 [01:13<00:18, 18.34s/it]

<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> disc: 0.693290, encoder : 0.693200, decoder : 27.000544
<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> epoch #5
<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> disc: 0.693295, encoder : 0.693197, decoder : 27.000544


100%|██████████| 5/5 [01:31<00:00, 18.32s/it]

<_io.TextIOWrapper name='./fpt.aae.log' mode='w' encoding='UTF-8'> disc: 0.693246, encoder : 0.693184, decoder : 26.761305



