In [8]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
from tensorflow_probability import bijectors as tfb
import utils as utl

In [20]:
# Test Bijectors
class GMC_bijector(tfb.Bijector):
    """Coordinate-wise bijector built from Gaussian-mixture marginals.

    The forward pass sends column ``j`` of the input through ``utl.gmm_cdf``
    and the inverse pass sends it through ``utl.gmm_icdf``, in both cases
    using the mixture weights, the ``j``-th column of the component means and
    the ``j``-th diagonal entry of each component covariance.

    NOTE(review): ``utl`` is not visible from this cell; the helpers are
    presumed to be the element-wise CDF / inverse CDF / log-pdf of a 1-D
    Gaussian mixture taking ``(values, weights, means, variances)`` -- confirm
    against ``utils.py``.

    Args:
        ndims: Number of coordinates (columns) transformed by the bijector.
        ncomps: Number of mixture components.
        inputs: Indexable of three items:
            ``inputs[0]`` -- mixture weights (stored as ``alpha``);
            ``inputs[1]`` -- component means, indexed as ``[:, j]``;
            ``inputs[2]`` -- per-component factors ``L_k`` (indexed by ``k``),
            from which the covariances ``L_k @ L_k^T`` are built.
        forward_min_event_ndims: Forwarded to ``tfb.Bijector``.
        validate_args: Forwarded to ``tfb.Bijector``.
        name: Forwarded to ``tfb.Bijector``.
    """

    def __init__(self,ndims,ncomps,inputs,forward_min_event_ndims=1, validate_args: bool = False,name="gmc"):
        super().__init__(
            validate_args=validate_args, forward_min_event_ndims=forward_min_event_ndims, name=name
        )
        self.ndims = ndims
        self.ncomps = ncomps
        self.alpha = inputs[0]
        self.mu_vectors = inputs[1]
        self.chol_matrices = inputs[2]
        # One covariance matrix per component, stacked along a new leading
        # axis so that cov_matrices[:, j, j] yields the per-component variance
        # of coordinate j.
        self.cov_matrices = tf.stack(
            [
                tf.matmul(self.chol_matrices[k], self.chol_matrices[k], transpose_b=True)
                for k in range(ncomps)
            ]
        )

    def _marginal_params(self, j):
        """Return ``(weights, means, variances)`` of the ``j``-th marginal mixture."""
        return self.alpha, self.mu_vectors[:, j], self.cov_matrices[:, j, j]

    def _forward(self, x_mat):
        """Map each column of ``x_mat`` through ``utl.gmm_cdf`` of its marginal.

        Returns a tensor whose ``j``-th column is the transformed ``j``-th
        column of ``x_mat``.
        """
        # Iterate over self.ndims rather than a notebook-level ``ndims`` so the
        # class does not depend on hidden kernel state.
        u_cols = [
            utl.gmm_cdf(x_mat[:, j], *self._marginal_params(j))
            for j in range(self.ndims)
        ]
        # Stacking yields one row per coordinate; transpose back to columns.
        return tf.transpose(tf.stack(u_cols))

    def _inverse(self, u_mat):
        """Map each column of ``u_mat`` through ``utl.gmm_icdf`` of its marginal.

        Returns a tensor whose ``j``-th column is the transformed ``j``-th
        column of ``u_mat``.
        """
        x_cols = [
            utl.gmm_icdf(u_mat[:, j], *self._marginal_params(j))
            for j in range(self.ndims)
        ]
        return tf.transpose(tf.stack(x_cols))

    def _forward_log_det_jacobian(self,x_mat):
        """Sum ``utl.gmm_lpdf`` over the coordinates of each row of ``x_mat``.

        The transform acts on every coordinate independently, so its Jacobian
        is diagonal and the log-determinant is the sum of the per-coordinate
        log-derivatives.

        Returns a tensor with one value per row of ``x_mat``.
        """
        log_pdf_terms = [
            utl.gmm_lpdf(x_mat[:, j], *self._marginal_params(j))
            for j in range(self.ndims)
        ]
        if not log_pdf_terms:
            # Degenerate zero-coordinate case: log-det of an empty product.
            # tf.shape handles a leading dimension that is not statically known.
            return tf.zeros(tf.shape(x_mat)[:1], dtype=x_mat.dtype)
        # Summing the terms directly keeps their dtype; a default float32
        # zero accumulator would clash with float64 log-densities.
        return tf.add_n(log_pdf_terms)

    def _inverse_log_det_jacobian(self, u_mat):
        """Return the negated forward log-det-Jacobian evaluated at ``_inverse(u_mat)``."""
        return -self._forward_log_det_jacobian(self._inverse(u_mat))

                        