In [None]:

def uniform_with_smooth_edge(a, b, h):
    """Uniform prior on [a, b] with raised-cosine edges of width h.

    The density is constant on [a, b] and decays smoothly to zero on
    [a - h, a] and [b, b + h] as 0.5 * (1 + cos(pi * y)), where y in [0, 1]
    is the normalized distance into the edge region. Each edge integrates to
    h / 2, so the normalization constant is 1 / (b - a + h).

    Args:
        a: Lower end of the flat region.
        b: Upper end of the flat region (expected b >= a).
        h: Width of each smooth edge; must be positive (it is a divisor).

    Returns:
        A pair ``(density, generate)``:
            density(x): the normalized prior density p(x) -- NOT -log p(x) --
                evaluated elementwise (broadcasts over array inputs and keeps
                the input's floating dtype); zero outside (a - h, b + h).
            generate(key): a single scalar sample from the prior, drawn by
                rejection sampling with a uniform proposal on [a - h, b + h].
                Use ``jax.vmap`` over split keys for many samples.
    """
    normalization = b - a + h
    half_width = (b - a) * 0.5
    center = a + half_width

    def _profile(x):
        """Return the edge coordinate y and the unnormalized shape at x.

        y < 0 inside [a, b] and 0 <= y <= 1 inside the edge regions; the
        returned shape is 1 on the flat part and the raised cosine on the
        edges (only meaningful for y <= 1).
        """
        y = (jnp.abs(x - center) - half_width) / h
        edge = 0.5 * (1. + jnp.cos(jnp.pi * y))
        return y, jnp.where(y < 0., 1., edge)

    def density(x):
        y, shape = _profile(x)
        # For y >= 1 the cosine would rise again, so clamp the tails to zero.
        return jnp.where(y < 1., shape, 0.) / normalization

    # Rejection sampling: propose uniformly on [a - h, b + h] (where y <= 1)
    # and accept with probability equal to the unnormalized shape, which is
    # bounded by 1 there.
    def proposal(state):
        _, _, key = state
        key_x, key_accept, key_next = jax.random.split(key, 3)
        x = jax.random.uniform(key_x, minval=a - h, maxval=b + h)
        _, accept_prob = _profile(x)
        # bernoulli already yields a boolean array.
        reject = jax.random.bernoulli(key_accept, 1. - accept_prob)
        return x, reject, key_next

    def generate(key):
        # Loop carry is (sample, reject_flag, key); iterate until accepted.
        init = (0., True, key)
        sample, _, _ = jax.lax.while_loop(lambda state: state[1], proposal, init)
        return sample

    return density, generate



# Prior parameters: flat region [lower, upper] with smooth edges of width edge_width.
lower, upper, edge_width = 4., 7., 0.5
p, g = uniform_with_smooth_edge(lower, upper, edge_width)
key = jax.random.PRNGKey(0)

# Draw prior samples with the rejection sampler, one PRNG key per sample.
sample_keys = jax.random.split(key, 10000)
X = jax.vmap(g)(sample_keys)

# Evaluate the density on a grid extending 0.5 beyond the support on each side.
x = jnp.linspace(lower - edge_width - 0.5, upper + edge_width + 0.5, 1000)
plt.plot(x, p(x))

# Set to True to overlay a histogram of the samples on the analytic density.
overlay_samples = False
if overlay_samples:
    plt.hist(jax.device_get(X), bins=30, density=True)

plt.ylabel('prior density')
plt.xlabel('log frequency')
# Ticks mark the outer and inner edges of the support; the labels below are
# specific to the parameter values chosen above.
tick_positions = [lower - edge_width, lower, upper, upper + edge_width]
tick_labels = [r'$\frac{1.5}{T}$', r'$\frac{2}{T}$', r'$\frac{1}{60}$', r'$\frac{1}{45}$']
plt.xticks(tick_positions, tick_labels)
plt.show()