# Variational Inference with Normalizing Flows
[Paper Link](https://arxiv.org/abs/1505.05770#:~:text=Our%20approximations%20are%20distributions%20constructed,level%20of%20complexity%20is%20attained.)

## Import modules

In [1]:
import os 
import numpy as np 
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("whitegrid")
%matplotlib inline

## Model 

### Planar Flows
$ f(z) = z + u\, h(w^\top z + b) $, where $h$ is a smooth element-wise non-linearity (here $\tanh$ or ELU)

In [137]:
class PlanarFlow(layers.Layer):
    """Planar normalizing flow f(z) = z + u_hat * h(w^T z + b).

    Implements the planar transformation of Rezende & Mohamed (2015). The raw
    trainable vector ``u`` is reparameterised into ``u_hat`` with
    ``w^T u_hat = -1 + softplus(w^T u) > -1``, which keeps the map invertible
    for activations whose derivative lies in (0, 1] (tanh, ELU).

    Args:
        input_dim: dimensionality of the latent vector z.
        activation: name of the non-linearity h, ``'tanh'`` or ``'elu'``.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, input_dim, activation='tanh', **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.activation, self.deriv_activation = self.getActivation(activation)

    def getActivation(self, activation):
        """Return the pair ``(h, h')`` for the named activation.

        Raises:
            ValueError: if ``activation`` is neither ``'elu'`` nor ``'tanh'``.
        """
        if activation == 'elu':
            f = tf.nn.elu
            # elu'(x) is 1 for x >= 0 and exp(x) otherwise. tf.where avoids
            # multiplying a float tensor by a bool tensor, and clamping the exp
            # argument keeps the unselected branch finite so it cannot inject
            # NaN into the gradient.
            df = lambda x: tf.where(
                x >= 0, tf.ones_like(x), tf.math.exp(tf.minimum(x, 0.0))
            )
            return f, df
        if activation == 'tanh':
            f = tf.math.tanh
            df = lambda x: 1 - tf.math.tanh(x) ** 2
            return f, df
        raise ValueError(
            f"Unsupported activation {activation!r}; expected 'elu' or 'tanh'."
        )

    def build(self, input_dim):
        # Keras passes the input shape here; weight shapes are taken from
        # ``self.input_dim`` instead, so the argument is not used.
        # add_weight registers the parameters as Keras layer weights so they
        # are listed in ``trainable_weights`` and updated by the optimizer.
        self.w = self.add_weight(
            name='w',
            shape=(1, self.input_dim),
            initializer='random_normal',
            dtype='float32',
            trainable=True,
        )
        self.u = self.add_weight(
            name='u',
            shape=(1, self.input_dim),
            initializer='random_normal',
            dtype='float32',
            trainable=True,
        )
        self.b = self.add_weight(
            name='b',
            shape=(1, 1),
            initializer='random_normal',
            dtype='float32',
            trainable=True,
        )

    def _u_hat(self):
        """Return ``u`` adjusted so that ``w^T u_hat = -1 + softplus(w^T u)``."""
        wu = tf.matmul(self.w, tf.transpose(self.u))  # (1, 1)
        m_wu = -1.0 + tf.math.softplus(wu)
        # Shift u along w so its projection onto w equals m(w^T u) > -1.
        return self.u + (m_wu - wu) * self.w / tf.reduce_sum(tf.square(self.w))

    def call(self, z):
        """Apply the flow to ``z`` of shape (batch, input_dim).

        Returns:
            ``(f, log_det)``: the transformed samples, shape (batch, input_dim),
            and ``log|det df/dz|`` per sample, shape (batch, 1).
        """
        u_hat = self._u_hat()
        lin = tf.matmul(z, tf.transpose(self.w)) + self.b  # (batch, 1)
        f = z + u_hat * self.activation(lin)
        # psi(z) = h'(w^T z + b) * w, so det(df/dz) = 1 + u_hat^T psi(z).
        psi = self.deriv_activation(lin) * self.w
        log_det = tf.math.log(
            tf.math.abs(1 + tf.matmul(psi, tf.transpose(u_hat)))
        )
        return f, log_det

### Radial Flows
$ f(z) = z + \beta\, h(\alpha, r)(z - z_0) $, where $r = \lVert z - z_0 \rVert$ and $h(\alpha, r) = 1/(\alpha + r)$

In [244]:
class RadialFlow(layers.Layer):
    """Radial normalizing flow f(z) = z + beta * h(alpha, r) * (z - z0).

    Here r = ||z - z0|| and h(alpha, r) = 1 / (alpha + r). Following Rezende &
    Mohamed (2015), alpha = exp(log_alpha) > 0, and the effective
    beta = -alpha + softplus(raw beta) > -alpha keeps the map invertible and
    the log-determinant finite.

    Args:
        input_dim: dimensionality of the latent vector z.
        activation: accepted for interface parity with ``PlanarFlow``; the
            radial transform itself does not use it.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, input_dim, activation='tanh', **kwargs):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.activation, self.deriv_activation = self.getActivation(activation)

    def getActivation(self, activation):
        """Return the pair ``(h, h')`` for the named activation.

        Raises:
            ValueError: if ``activation`` is neither ``'elu'`` nor ``'tanh'``.
        """
        if activation == 'elu':
            f = tf.nn.elu
            # tf.where avoids a float * bool dtype error; clamping the exp
            # argument keeps the unselected branch finite (no NaN gradients).
            df = lambda x: tf.where(
                x >= 0, tf.ones_like(x), tf.math.exp(tf.minimum(x, 0.0))
            )
            return f, df
        if activation == 'tanh':
            f = tf.math.tanh
            df = lambda x: 1 - tf.math.tanh(x) ** 2
            return f, df
        raise ValueError(
            f"Unsupported activation {activation!r}; expected 'elu' or 'tanh'."
        )

    def build(self, input_dim):
        # Keras passes the input shape here; weight shapes come from
        # ``self.input_dim``, so the argument is not used.
        self.z0 = self.add_weight(
            name='z0',
            shape=(1, self.input_dim),
            initializer='random_normal',
            dtype='float32',
            trainable=True,
        )
        self.log_alpha = self.add_weight(
            name='log_alpha',
            shape=(1, 1),
            initializer='random_normal',
            dtype='float32',
            trainable=True,
        )
        # Unconstrained parameter; the effective beta is derived in ``call``.
        self.beta = self.add_weight(
            name='beta',
            shape=(1, 1),
            initializer='random_normal',
            dtype='float32',
            trainable=True,
        )

    def call(self, z):
        """Apply the flow to ``z`` of shape (batch, input_dim).

        Returns:
            ``(f, log_det)``: the transformed samples, shape (batch, input_dim),
            and ``log|det df/dz|`` per sample, shape (batch, 1).
        """
        alpha = tf.math.exp(self.log_alpha)  # (1, 1), > 0
        beta = -alpha + tf.math.softplus(self.beta)  # (1, 1), > -alpha
        z_sub = z - self.z0  # (batch, input_dim)
        r = tf.norm(z_sub, axis=1, keepdims=True)  # (batch, 1)
        h = 1.0 / (alpha + r)
        dh = -1.0 / tf.square(alpha + r)  # dh/dr
        f = z + beta * h * z_sub

        # det = (1 + beta*h)^(d-1) * (1 + beta*h + beta*h'(alpha, r)*r)
        log_det = (self.input_dim - 1) * tf.math.log(1.0 + beta * h) + tf.math.log(
            1.0 + beta * h + beta * dh * r
        )
        return f, log_det


### VAE_Flow

## Fit model 

In [None]:
# NOTE(review): `getModel` is defined in the "Sample model" cell further down;
# that cell must be executed before this one or this line raises NameError.
model = getModel(10)
# Push a single random 10-dimensional point (entries drawn from [0, 1)) through
# the model and report the shapes of both outputs.
zh, log_det = model(np.random.uniform(low=0, high=1, size=(1, 10)))
for output in (zh, log_det):
    print(output.shape)

## Sample model 

In [None]:
def getModel(dim, activation='tanh'):
    """Build a functional Keras model wrapping a single ``PlanarFlow`` layer.

    Args:
        dim: dimensionality of the latent vector; the model input has shape
            (batch, dim).
        activation: non-linearity passed to ``PlanarFlow`` (``'tanh'`` or
            ``'elu'``).

    Returns:
        A ``keras.models.Model`` whose outputs are ``[zh, log_det]``: the
        transformed samples and their per-sample log-determinants.
    """
    # Pass the shape as an explicit tuple; newer Keras versions reject a bare int.
    inputs = keras.layers.Input(shape=(dim,))
    zh, log_det = PlanarFlow(dim, activation=activation)(inputs)
    return keras.models.Model(inputs=[inputs], outputs=[zh, log_det])