In [None]:
class MaskedAutoregressiveFlow:
    """
    Implements a Masked Autoregressive Flow, which is a stack of mades such that the random numbers which drive made i
    are generated by made i-1. The first made is driven by standard gaussian noise. In the current implementation, all
    mades are of the same type. If there is only one made in the stack, then it's equivalent to a single made.
    Optionally, a batch normalization layer is inserted after each made.
    """

    def __init__(self, n_inputs, n_hiddens, act_fun, n_mades, batch_norm=False,
                 input_order='sequential', mode='sequential', input=None, momentum=None):
        """
        Constructor.
        :param n_inputs: number of inputs
        :param n_hiddens: list with number of hidden units for each hidden layer
        :param act_fun: tensorflow activation function
        :param n_mades: number of mades; must be at least 1
        :param batch_norm: whether to use batch normalization between mades
        :param input_order: input order of the first made in the stack (the last made applied when generating);
                            unless it is 'random', each following made uses the reverse of the previous made's order
        :param mode: strategy for assigning degrees to hidden nodes: can be 'random' or 'sequential'
        :param input: tensorflow placeholder to serve as input; if None, a new placeholder is created
        :param momentum: stored as self.momentum; nothing in this class reads it (the batch normalization
                         layers built here are constructed without it)
        :raises ValueError: if n_mades is less than 1
        """

        # fail early with a clear message; with no mades, self.mades[0] below would raise IndexError
        if n_mades < 1:
            raise ValueError('n_mades must be at least 1, got {}'.format(n_mades))

        # save input arguments
        self.n_inputs = n_inputs
        self.n_hiddens = n_hiddens
        self.act_fun = act_fun
        self.n_mades = n_mades
        self.batch_norm = batch_norm
        self.momentum = momentum
        self.mode = mode

        # NOTE(review): `dtype` is expected to be a module-level name defined elsewhere in the file
        self.input = tf.placeholder(dtype=dtype, shape=[None, n_inputs], name='x') if input is None else input
        # defaults to False, so batch norm uses its saved moments unless True is fed explicitly
        self.training = tf.placeholder_with_default(False, shape=(), name="training")
        self.parms = []

        self.mades = []
        self.bns = []
        self.moments = []
        self.assign_bns = []
        self.u = self.input
        self.logdet_dudx = 0.0

        for _ in range(n_mades):

            # create a new made, fed with the output of the previous layer
            made = mades.GaussianMade(n_inputs, n_hiddens, act_fun, input_order, mode, self.u)
            self.mades.append(made)
            self.parms += made.parms

            # reverse the input order for the next made, unless orders are random
            input_order = input_order if input_order == 'random' else made.input_order[::-1]

            # inverse autoregressive transform
            # NOTE(review): 0.5 * logp as the log-det term presumes made.logp is the log precision — confirm in mades
            self.u = made.u
            self.logdet_dudx += 0.5 * tf.reduce_sum(made.logp, axis=1, keepdims=True)

            # batch normalization
            if batch_norm:
                bn = BatchNormalization()
                # batch (mean, variance) of the pre-normalization activations; update_batch_norm assigns these
                moments = tf.nn.moments(self.u, [0])
                batch_variance = moments[1]
                self.u = bn(self.u, training=self.training)
                self.parms += [bn.loggamma, bn.beta]
                # the log-det uses the batch variance while training and bn.variance otherwise
                variance = tf.cond(self.training, lambda: batch_variance, lambda: bn.variance)
                # NOTE(review): 1e-5 presumably matches the epsilon used inside BatchNormalization — confirm
                self.logdet_dudx += tf.reduce_sum(bn.loggamma) - 0.5 * tf.reduce_sum(tf.log(variance + 1e-5))
                self.bns.append(bn)
                self.moments.append(moments)
                self.assign_bns.append(tf.assign(bn.mean, moments[0]))
                self.assign_bns.append(tf.assign(bn.variance, moments[1]))

        self.input_order = self.mades[0].input_order

        # log likelihoods: standard gaussian log density of u plus log|det du/dx|
        self.L = tf.add(-0.5 * n_inputs * np.log(2 * np.pi) - 0.5 * tf.reduce_sum(self.u ** 2, axis=1, keepdims=True),
                        self.logdet_dudx, name='L')

        # train objective: mean negative log likelihood. The name goes on the negation so that the graph op
        # called 'trn_loss' really is the loss (naming the reduce_mean would give it the opposite sign).
        self.trn_loss = tf.negative(tf.reduce_mean(self.L), name='trn_loss')

    def eval(self, x, sess, log=True, training=False):
        """
        Evaluate log probabilities for given inputs.
        :param x: data matrix where rows are inputs
        :param sess: tensorflow session where the graph is run
        :param log: whether to return probabilities in the log domain
        :param training: in training, data mean and variance is used for batchnorm
                         while outside training the saved mean and variance is used
        :return: log probabilities log p(x) (or p(x) when log is False), one row per input
        """

        lprob = sess.run(self.L, feed_dict={self.input: x, self.training: training})

        return lprob if log else np.exp(lprob)

    def update_batch_norm(self, x, sess):
        """
        Updates batch normalization moments with the values obtained in data set x.
        :param x: data matrix whose moments will be used for the update
        :param sess: tensorflow session where the graph is run
        :return: None
        """
        # training=True so each layer's moments are computed from batch statistics of x
        sess.run(self.assign_bns, feed_dict={self.input: x, self.training: True})

    def gen(self, sess, n_samples=1, u=None):
        """
        Generate samples, by propagating random numbers through each made.
        :param sess: tensorflow session where the graph is run
        :param n_samples: number of samples
        :param u: random numbers to use in generating samples; if None, new random numbers are drawn
        :return: samples
        """

        x = rng.randn(n_samples, self.n_inputs) if u is None else u

        # getattr guards against instances that lack the batch_norm attribute
        if getattr(self, 'batch_norm', False):

            # invert the stack from the last layer back to the first: undo each batch norm, then its made
            for made, bn in zip(self.mades[::-1], self.bns[::-1]):
                x = bn.eval_inv(sess, x)
                x = made.gen(sess, n_samples, x)

        else:

            for made in self.mades[::-1]:
                x = made.gen(sess, n_samples, x)

        return x

    def calc_random_numbers(self, x, sess, training=False):
        """
        Given a dataset, calculate the random numbers used internally to generate the dataset.
        :param x: numpy array, rows are datapoints
        :param sess: tensorflow session where the graph is run
        :param training: as in eval; False (the default) uses the saved batch norm mean and variance
        :return: numpy array, rows are corresponding random numbers
        """

        return sess.run(self.u, feed_dict={self.input: x, self.training: training})