In [1]:
import jax.numpy as jnp
from jax import grad,jacobian,vmap
import pandas as pd
import numpy as np
import jax.scipy.special as sps
import jax.scipy.stats as scs
from jax import random
import pymc as pm
from scipy.special import digamma
import matplotlib.pyplot as plt

In [2]:
class priors:
    
    """
    Abstract interface for a prior expressed in pivot form.
    
    A concrete prior exposes a variable X that does not depend on the
    hyperparameters lambda, together with the maps between X and theta.
    Every method below must be overridden by a subclass; calling one that
    has not been overridden raises NotImplementedError.
    """
    
    def _not_implemented(self, method_name):
        
        # Build one uniform error for every abstract method of the interface.
        return NotImplementedError(
            f"{type(self).__name__} must implement {method_name}()"
        )
    
    def pivot(self, theta, *args, **kwds):
        
        """
        Pivot function of the distribution
        
        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        
        raise self._not_implemented("pivot")
    
    def inverse_pivot(self, x, *args, **kwds):
        
        """
        Inverse pivot function of the distribution
        
        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        
        raise self._not_implemented("inverse_pivot")
    
    def sample_x(self, *args, **kwds):
        
        """
        Function to sample from X (independent from the hyperparameters lambda)
        
        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        
        raise self._not_implemented("sample_x")
    
    def grad_inverse_pivot(self, x, *args, **kwds):
        
        """
        Gradient of the inverse pivot with respect to the hyperparameters lambda
        
        Raises
        ------
        NotImplementedError
            Always, unless overridden by a subclass.
        """
        
        raise self._not_implemented("grad_inverse_pivot")
    


class model_probabilities:
    
    """
    Abstract interface for the probabilities a model assigns to a partition.
    
    Subclasses provide `cdf` (continuous case) and/or `pdf` (used when
    `is_discrete` is True) and `grad_partition_prob`; `partition_prob` is
    implemented here in terms of those.
    """
    
    def __init__(self, is_discrete=False):
        # When True, partition_prob evaluates pdf() at the partition itself
        # instead of differencing cdf() at its two end points.
        self.is_discrete = is_discrete
    
    def cdf(self, y, *args, **kwds):
        
        """
        Cumulative distribution function evaluated at y.
        
        Raises NotImplementedError unless overridden by a subclass.
        """
        
        raise NotImplementedError(f"{type(self).__name__} must implement cdf()")
    
    def pdf(self, y, *args, **kwds):
        
        """
        Density / mass function evaluated at y.
        
        Raises NotImplementedError unless overridden by a subclass.
        """
        
        raise NotImplementedError(f"{type(self).__name__} must implement pdf()")
    
    def partition_prob(self, partition, *args, **kwds):
        
        """
        Probability of a partition.
        
        Discrete models return pdf(partition); otherwise the partition is read
        as the pair (a, b) = (partition[0], partition[1]) and the result is
        cdf(b) - cdf(a). Extra arguments are forwarded to pdf / cdf.
        """
        
        if self.is_discrete:
            return self.pdf(partition, *args, **kwds)
        
        a = partition[0]
        b = partition[1]
        
        return self.cdf(b, *args, **kwds) - self.cdf(a, *args, **kwds)
        
    def grad_partition_prob(self, partition, *args, **kwds):
        
        """
        Gradient of the partition probability with respect to the model
        parameters.
        
        Raises NotImplementedError unless overridden by a subclass.
        """
        
        raise NotImplementedError(
            f"{type(self).__name__} must implement grad_partition_prob()"
        )
    

class gaussian_prior(priors):
    
    """
    Gaussian prior in pivot form: theta = mu + sigma * x, where x is drawn
    from a standard normal that does not depend on (mu, sigma).
    """
    
    def __init__(self, mu=None, sigma=None):
        super().__init__()
        self.mu = mu
        self.sigma = sigma
        self.is_discrete = False
    
    def _update(self, mu, sigma):
        
        """
        Replace the stored hyperparameters (mu, sigma).
        """
        
        self.mu = mu
        self.sigma = sigma
        
    def pivot(self, theta):
        
        """
        Standardise theta: (theta - mu) / sigma.
        """
        
        return (theta - self.mu)/self.sigma
    
    def inverse_pivot(self, x):
        
        """
        Map a pivot value x back to theta: x * sigma + mu.
        """
        
        return x * self.sigma + self.mu
    
    def grad_inverse_pivot(self, x):
        
        """
        Gradient of inverse_pivot(x) with respect to (mu, sigma).
        
        Since theta = x * sigma + mu, d theta / d mu = 1 and
        d theta / d sigma = x.
        
        Returns an array whose leading axis has length 2: entry 0 is dmu,
        entry 1 is dsigma. The remaining axes follow the shape of x, so a
        length-n vector gives shape (2, n) and a scalar gives shape (2,).
        """
        
        # Work on the array form so scalars and length-1 inputs are handled
        # the same way as longer vectors (no len() call on the input).
        dsigma = jnp.asarray(x)
        dmu = jnp.ones(jnp.shape(dsigma))
        
        return jnp.stack([dmu, dsigma]) ## 1st row is dmu, 2nd row is dsigma
    
    def sample_x(self, size, rng=None):
        
        """
        Draw standard-normal pivot samples.
        
        Parameters
        ----------
        size : int or tuple of ints
            Shape of the returned sample.
        rng : numpy.random.Generator, optional
            Generator to draw from. When omitted, NumPy's global random state
            is used (np.random.normal), so np.random.seed controls the draw.
        """
        
        if rng is None:
            return np.random.normal(size=size)
        
        return rng.standard_normal(size=size)
    
        



class gaussian_model_probs(model_probabilities):
    
    """
    Gaussian likelihood y ~ N(mu, sigma^2): interval probabilities and their
    gradient with respect to (mu, sigma).
    """
    
    def __init__(self, mu=None, sigma=None):
        super().__init__()
        self.mu = mu
        self.sigma = sigma
        self.is_discrete = False
    
    def _update(self, mu, sigma):
        
        """
        Replace the stored parameters (mu, sigma).
        """
        
        self.mu = mu
        self.sigma = sigma
        
    def cdf(self, x):
        
        """
        Normal CDF at x with loc = mu and scale = sigma.
        """
        
        # jnp (not np) conversion keeps the call usable under JAX tracing.
        x = jnp.asarray(x)
        return scs.norm.cdf(x, loc = self.mu, scale = self.sigma)
    
    def pdf(self, x):
        
        """
        Normal density at x with loc = mu and scale = sigma.
        """
        
        x = jnp.asarray(x)
        return scs.norm.pdf(x, loc = self.mu, scale = self.sigma)
    
    def model_prob_gradient(self, partition):
        
        """
        Gradient of P(a < y <= b) = cdf(b) - cdf(a) with respect to (mu, sigma),
        where a = partition[0] and b = partition[1].
        
        With z = (y - mu) / sigma and f the N(mu, sigma^2) density:
            d cdf(y) / d mu    = -f(y)
            d cdf(y) / d sigma = -((y - mu) / sigma) * f(y)
        
        Returns an array whose 1st row is dmu and 2nd row is dsigma.
        """
        
        a = partition[0]
        b = partition[1]
        
        # norm.pdf with scale=sigma already carries the 1/sigma factor, so no
        # additional division by sigma belongs in dmu and only one in dsigma.
        pdf_a = scs.norm.pdf(a, loc=self.mu, scale=self.sigma)
        pdf_b = scs.norm.pdf(b, loc=self.mu, scale=self.sigma)
        
        dmu = pdf_a - pdf_b
        dsigma = ((a - self.mu) * pdf_a - (b - self.mu) * pdf_b) / self.sigma
        
        return jnp.array([dmu, dsigma]) ## 1st row is dmu, 2nd row is dsigma
    
    def grad_partition_prob(self, partition):
        
        """
        Gradient of the partition probability; same result as
        model_prob_gradient, exposed under the base-class method name.
        """
        
        return self.model_prob_gradient(partition)
    
    



   

In [3]:
# Monte Carlo estimate of the gradient of the prior-predictive probability of
# `partition` with respect to (mu_1, sigma, sigma_1), using the pivot form
# theta = mu_1 + sigma_1 * x.
mu_1 = 2.
sigma = 1.
sigma_1 = 2.

partition = np.array([-2.,3.])

# Fix NumPy's global random state so the estimate is reproducible on re-run.
np.random.seed(0)

gs_prior = gaussian_prior(mu = mu_1, sigma = sigma_1)

x_samples = gs_prior.sample_x(1000000)

theta = gs_prior.inverse_pivot(x_samples)

gs_probs = gaussian_model_probs(mu = theta, sigma = sigma)

# Evaluate the per-sample gradient once and slice it, rather than recomputing
# it over all samples for each row.
model_prob_grad_samples = gs_probs.model_prob_gradient(partition)

prob_y_given_theta_dtheta_samples = model_prob_grad_samples[0,:]

prob_y_given_theta_dsigma_samples = model_prob_grad_samples[1,:]

pivot_gaussian_inverse_grad_samples = gs_prior.grad_inverse_pivot(x_samples)


# Chain rule through theta for the prior hyperparameters (mu_1, sigma_1);
# sigma enters the likelihood directly.
ppe_dmu1 = jnp.mean(prob_y_given_theta_dtheta_samples * pivot_gaussian_inverse_grad_samples[0,:])
ppe_dsigma1 = jnp.mean(prob_y_given_theta_dtheta_samples * pivot_gaussian_inverse_grad_samples[1,:])
ppe_dsigma = jnp.mean(prob_y_given_theta_dsigma_samples)

print(ppe_dmu1, ppe_dsigma, ppe_dsigma1)

-0.12508571 -0.06117869 -0.1223839


In [4]:
def get_gaussian_probs(partition, lam):
    
    """
    Probability that a N(mu_1, sigma**2 + sigma_1**2) variable falls between
    partition[0] and partition[1].
    
    lam is indexed as (mu_1, sigma, sigma_1); partition as (a, b).
    """
    
    mu_1, sigma, sigma_1 = lam[0], lam[1], lam[2]
    lower, upper = partition[0], partition[1]
    
    # Both end points are standardised with the same total standard deviation.
    total_sd = jnp.sqrt(sigma**2 + sigma_1**2)
    
    upper_cdf = scs.norm.cdf((upper - mu_1)/total_sd)
    lower_cdf = scs.norm.cdf((lower - mu_1)/total_sd)
    
    return upper_cdf - lower_cdf

# Closed-form reference: differentiate the analytic interval probability with
# respect to lam = (mu_1, sigma, sigma_1).
lam = jnp.asarray([2., 1., 2.])
partition = jnp.asarray([-2., 3.])

analytic_grad = grad(get_gaussian_probs, argnums=1)(partition, lam)
print(analytic_grad)

[-0.12541339 -0.06110352 -0.12220705]
