In [2]:
# Numerical analysis
import numpy as np
import jax.numpy as jnp
from jax import grad, jacfwd, jit, random
from jax.nn import sigmoid

# Bayesian inference
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import arviz as az

# Visualisation
import matplotlib.pyplot as plt
import seaborn as sns

# Seed every RNG used below: JAX draws need an explicit key, while
# generate_data draws its noise from NumPy's global generator.
SEED = 0
prng_key = random.PRNGKey(SEED)
np.random.seed(SEED)

In [None]:
def generate_data(dim, pstar, A, sigma, nrep):
  """
  Generate synthetic data.

  Parameters:
    int dim: The dimension of the dataset
    int pstar: The number of true non-zero signals
    real A: The magnitude of true non-zero signals
    real sigma: Standard deviation of the noise
    nrep: The number of replications
  """
  data = np.zeros((nrep, dim))       # Initialise
  data[:, :pstar] += A               # Add signal
  data += np.random.normal(0, sigma, size=(nrep, dim))  # Add noise
  return data

In [None]:
def model(lam, sigma=1, y=None):
    '''
    Prior-only NumPyro model of a Binary Concrete (relaxed Bernoulli) variable
    for each entry of y.

    For every entry, with logistic noise L = log U - log(1 - U), U ~ Uniform(0, 1):
        gamma = sigmoid((L + log alpha) / lam),   alpha = exp(y)
    Because log(alpha) == y, y is used directly as the location.

    Parameters:
       real lam: temperature coefficient (> 0). Smaller values make the sigmoid
                 sharper, pushing gammas towards 0 or 1.
       real sigma: stdev of y. NOTE(review): currently unused by the model.
       array y: dependent variable. Required: its shape sets the shape of
                every sample/deterministic site.

    Raises:
       ValueError: if y is None.
    '''
    if y is None:
        raise ValueError("model requires y: its shape defines the latent sites")
    shape = jnp.shape(y)

    # Recorded for inspection only. Using y instead of log(exp(y)) below avoids
    # exp overflow (and the resulting NaN gradients) for large |y|.
    numpyro.deterministic("alphas", jnp.exp(y))
    U = numpyro.sample("U", dist.Uniform(0, 1).expand(shape))
    L = numpyro.deterministic("L", jnp.log(U) - jnp.log1p(-U))
    gammas = numpyro.deterministic("gammas", sigmoid((L + y) / lam))
    numpyro.deterministic("kappas", 1 - gammas)

REBAR estimator (Tucker et al., 2017) of $\frac{\partial}{\partial \theta} E[f(b)]$ for $b \sim \text{Bernoulli}(\theta)$, using the relaxation $\sigma_\lambda(x) := \sigma(x/\lambda)$ with temperature $\lambda$:

$$\begin{align*}
\text{REBAR} &= E\Big[\big(f(H(z)) - \eta\, f(\sigma_\lambda(\tilde{z}))\big) \frac{\partial}{\partial \theta} \log p(b)\Big|_{b=H(z)} \\
&\qquad + \eta \frac{\partial}{\partial \theta} f(\sigma_\lambda(z)) - \eta \frac{\partial}{\partial \theta} f(\sigma_\lambda(\tilde{z}))\Big] \\[6pt]
p(b) &= \theta^{b} (1-\theta)^{1-b} \\[6pt]
z &:= \log \frac{\theta}{1-\theta} + \log \frac{u}{1-u} \\
u &\sim \text{Uniform}(0, 1) \\[6pt]
\tilde{z} &:= \begin{cases}
\log \left(\frac{v}{1-v}\frac{1}{1-\theta} + 1\right) & \text{if } b=1 \\
-\log \left(\frac{v}{1-v}\frac{1}{\theta} + 1\right) & \text{if } b=0
\end{cases} \\
v &\sim \text{Uniform}(0, 1) \\[6pt]
H(z) &\approx \sigma_\lambda(z_\lambda) \\
z_\lambda &= \frac{\lambda^2+\lambda+1}{\lambda+1} \log\frac{\theta}{1-\theta} + \log \frac{u}{1-u}
\end{align*}$$

In [None]:
def rebar_estimator(fn, theta, b, data, eta, lam, key=None):
    """
    Per-sample REBAR estimates of d/dtheta E[f(b)] for b ~ Bernoulli(theta).

    Implements
        (f(H(z)) - eta f(sigma_lam(z~))) * d/dtheta log p(b)|_{b=H(z)}
        + eta d/dtheta f(sigma_lam(z)) - eta d/dtheta f(sigma_lam(z~))
    with sigma_lam(x) = sigmoid(x / lam) and the smoothed
    H(z) ~= sigma_lam(z_lam), z_lam = (lam^2+lam+1)/(lam+1) logit(theta) + logit(u).

    Parameters:
        fn: f, applied elementwise; must be JAX-differentiable and map an
            (n,) array to an (n,) array.
        theta: scalar Bernoulli probability in (0, 1), the differentiated parameter.
        b: 0 or 1, the outcome that z~ is conditioned on.
        data: array whose first dimension n sets the number of Monte Carlo samples.
        eta: control-variate scale.
        lam: temperature (> 0).
        key: optional jax PRNG key used to draw u and v. Defaults to PRNGKey(0).

    Returns:
        jnp array of shape (n,): one estimate per (u, v) pair.

    Raises:
        ValueError: if b is not 0 or 1.
    """
    if b not in (0, 1):
        raise ValueError("b must be 0 or 1")
    if key is None:
        key = random.PRNGKey(0)

    n = data.shape[0]
    # jax.random needs an explicit key; use independent subkeys for u and v.
    key_u, key_v = random.split(key)
    u = random.uniform(key_u, shape=(n,))
    v = random.uniform(key_v, shape=(n,))

    def logit(p):
        return jnp.log(p) - jnp.log1p(-p)

    logit_u = logit(u)

    def relaxed_z(th):
        # sigma_lam(z), with z = logit(theta) + logit(u)
        return sigmoid((logit(th) + logit_u) / lam)

    def relaxed_z_tilde(th):
        # sigma_lam(z~), with z~ drawn conditionally on b via v
        if b == 1:
            z_tilde = jnp.log(v / ((1 - v) * (1 - th)) + 1)
        else:
            z_tilde = -jnp.log(v / ((1 - v) * th) + 1)
        return sigmoid(z_tilde / lam)

    def smoothed_hard(th):
        # H(z) ~= sigma_lam(z_lam). The scale is (lam^2 + lam + 1) / (lam + 1);
        # the original cell had the typo `lam**+lam`.
        scale = (lam**2 + lam + 1) / (lam + 1)
        return sigmoid((scale * logit(th) + logit_u) / lam)

    H_z = smoothed_hard(theta)
    # d/dtheta log p(b) for p(b) = theta^b (1-theta)^(1-b), evaluated at b = H(z)
    score = H_z / theta - (1 - H_z) / (1 - theta)

    # Reparameterised per-sample derivatives: theta is scalar, so the forward-mode
    # Jacobian of an (n,)-valued function is an (n,) array.
    d_relaxed = jacfwd(lambda th: fn(relaxed_z(th)))(theta)
    d_relaxed_tilde = jacfwd(lambda th: fn(relaxed_z_tilde(th)))(theta)

    return ((fn(H_z) - eta * fn(relaxed_z_tilde(theta))) * score
            + eta * d_relaxed
            - eta * d_relaxed_tilde)
    

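Goal: find the temperature $\lambda$ that minimizes the variance of the REBAR estimator $r(\lambda)$.

Since $\text{Var}[r(\lambda)] = E[r(\lambda)^2] - E[r(\lambda)]^2$, and for an unbiased estimator $E[r(\lambda)]$ (the true gradient) does not depend on $\lambda$, only the second moment contributes:

$$\begin{align*}
\frac{\partial}{\partial \lambda} \text{Var}[r(\lambda)]
= \frac{\partial}{\partial \lambda} E[r(\lambda)^2]
= E\left[2r(\lambda)\frac{\partial r(\lambda)}{\partial \lambda}\right]
\end{align*}$$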

In [None]:
def optimizer(data, fn=None, theta=0.5, b=1, eta=1.0, lam=0.5,
              lr=0.05, n_steps=100, key=None):
    """
    Tune the REBAR temperature lambda by gradient descent on the estimator variance.

    Uses d/dlambda Var[r] = E[2 r dr/dlambda] = d/dlambda E[r^2]. This holds
    when E[r] does not depend on lambda. The expectation is approximated by the
    mean over the data.shape[0] per-sample estimates from rebar_estimator.
    lambda is optimised on the log scale so that it stays positive.

    Parameters:
        data: array whose first dimension sets the Monte Carlo sample size.
        fn: elementwise objective f. Defaults to the identity.
            NOTE(review): this is a placeholder; the original cell relied on
            an undefined global `fn`.
        theta, b, eta: forwarded to rebar_estimator (defaults are placeholders).
        lam: initial temperature (> 0).
        lr: step size on log(lambda).
        n_steps: number of gradient steps.
        key: jax PRNG key. Defaults to PRNGKey(0); a fresh subkey is used each step.

    Returns:
        The optimised temperature lambda (jnp scalar).

    Raises:
        ValueError: if lam is not positive.
    """
    if fn is None:
        fn = lambda x: x  # noqa: E731 - identity placeholder objective
    if lam <= 0:
        raise ValueError(f"lam must be positive, got {lam}")
    if key is None:
        key = random.PRNGKey(0)

    def second_moment(log_lam, step_key):
        r = rebar_estimator(fn, theta, b, data, eta, jnp.exp(log_lam), key=step_key)
        return jnp.mean(r ** 2)

    # Gradient of E[r^2] w.r.t. log(lambda), i.e. E[2 r dr/dlambda] * lambda.
    grad_step = jit(grad(second_moment))

    log_lam = jnp.log(lam)
    for _ in range(n_steps):
        key, step_key = random.split(key)
        log_lam = log_lam - lr * grad_step(log_lam, step_key)
    return jnp.exp(log_lam)
    