In [1]:
import numpy as np

import jax.numpy as jnp
from jax import jit, grad, value_and_grad, vmap, random
from jax.scipy.special import logsumexp

import jax.numpy as jnp
from jax import jit, vmap, grad, random, lax, ops
import numpy as np
import time
#from jax import partial

In [2]:
def genCovMat(key, d):
    return jnp.eye(d)


def logistic(theta, x):
    return 1/(1+jnp.exp(-jnp.dot(theta, x)))

batch_logistic = jit(vmap(logistic, in_axes=(None, 0)))
batch_benoulli = vmap(random.bernoulli, in_axes=(0, 0))

def gen_data(key, dim, N):
    """
    Generate data with dimension `dim` and `N` data points

    Parameters
    ----------
    key: uint32
        random key
    dim: int
        dimension of data
    N: int
        Size of dataset

    Returns
    -------
    theta_true: ndarray
        Theta array used to generate data
    X: ndarray
        Input data, shape=(N,dim)
    y_data: ndarray
        Output data: 0 or 1s. shape=(N,)
    """
    key, subkey1, subkey2, subkey3 = random.split(key, 4)
    print(f"generating data, with N={N} and dim={dim}")
    theta_true = random.normal(subkey1, shape=(dim, ))*jnp.sqrt(10)
    print(theta_true)
    print(jnp.shape(theta_true))
    print(type(theta_true))
    covX = genCovMat(subkey2, dim)
    X = jnp.dot(random.normal(subkey3, shape=(N,dim)), jnp.linalg.cholesky(covX))

    p_array = batch_logistic(theta_true, X)
    keys = random.split(key, N)
    y_data = batch_benoulli(keys, p_array).astype(jnp.int32)
    return theta_true, X, y_data

In [3]:
def build_grad_log_post(X, y_data, N):
    """
    Builds grad_log_post
    """
    @jit
    def loglikelihood(theta, x_val, y_val):
        return -logsumexp(jnp.array([0., (1.-2.*y_val)*jnp.dot(theta, x_val)])) #gets one vector 
    @jit
    def log_prior(theta):
        return -(0.5/10)*jnp.dot(theta,theta) #get one value 

    batch_loglik = jit(vmap(loglikelihood, in_axes=(None, 0,0)))

    def log_post(theta):
        return log_prior(theta) + N*jnp.mean(batch_loglik(theta, X, y_data), axis=0) #one value 

    grad_log_post = jit(grad(log_post))
    return grad_log_post

def build_value_and_grad_log_post(X, y_data, N):
    """
    Builds grad_log_post
    """
    @jit
    def loglikelihood(theta, x_val, y_val):
        return -logsumexp(jnp.array([0., (1.-2.*y_val)*jnp.dot(theta, x_val)]))

    @jit
    def log_prior(theta):
        return -(0.5/10)*jnp.dot(theta,theta)

    batch_loglik = jit(vmap(loglikelihood, in_axes=(None, 0,0)))

    def log_post(theta):
        return log_prior(theta) + N*jnp.mean(batch_loglik(theta, X, y_data), axis=0)

    val_and_grad_log_post = jit(value_and_grad(log_post))
    return val_and_grad_log_post

def build_batch_grad_log_post(X, y_data, N):
    """
    Builds grad_log_post that takes in minibatches X and y_data
    """
    @jit
    def loglikelihood(theta, x_val, y_val):
        return -logsumexp(jnp.array([0., (1.-2.*y_val)*jnp.dot(theta, x_val)]))

    @jit
    def log_prior(theta):
        return -(0.5/10)*jnp.dot(theta,theta)

    batch_loglik = jit(vmap(loglikelihood, in_axes=(None, 0,0)))

    def log_post(theta, X, y_data):
        return log_prior(theta) + N*jnp.mean(batch_loglik(theta, X, y_data), axis=0)

    grad_log_post = jit(grad(log_post))
    return grad_log_post



In [4]:
def ula_kernel(key, param, grad_log_post, dt):
    key, subkey = random.split(key)
    paramGrad = grad_log_post(param)
    param = param + dt*paramGrad + jnp.sqrt(2*dt)*random.normal(key=subkey, shape=(param.shape))
    return key, param


def ula_sampler_jax_kernel(key, grad_log_post, num_samples, dt, x_0, print_rate=500):
    dim, = x_0.shape
    samples = np.zeros((num_samples, dim))
    param = x_0
    print(f"Python loop with Jax kernel")
    for i in range(num_samples):
        key, param = ula_kernel(key, param, grad_log_post, dt)
        samples[i] = param
        if i%print_rate==0:
            print(f"Iteration {i}/{num_samples}")
    return samples

In [5]:
def sgld_kernel(key, param, grad_log_post, dt, X, y_data, minibatch_size):
    N, _ = X.shape
    key, subkey1, subkey2 = random.split(key, 3)
    idx_batch = random.choice(subkey1, N, shape=(minibatch_size,))
    paramGrad = grad_log_post(param, X[idx_batch], y_data[idx_batch])
    param = param + dt*paramGrad + jnp.sqrt(2*dt)*random.normal(key=subkey2, shape=(param.shape))
    return key, param


def sgld_sampler_jax_kernel(key, grad_log_post, num_samples, dt, x_0, X, y_data, minibatch_size, print_rate=500):
    dim, = x_0.shape
    samples = np.zeros((num_samples, dim))
    param = x_0
    print(f"Python loop with Jax kernel")
    for i in range(num_samples):
        key, param = sgld_kernel(key, param, grad_log_post, dt, X, y_data, minibatch_size)
        samples[i] = param
        if i%print_rate==0:
            print(f"Iteration {i}/{num_samples}")
    return samples

In [14]:
key = random.PRNGKey(0)
N = 1000
dim = 4
dt = 5e-3
num_samples = 100
print_rate = num_samples/2

theta_true, X, y_data = gen_data(key, dim, N)
print(theta_true)
print(f'xdata{X}')
print(f'ydata{(y_data)}')
grad_log_post = build_grad_log_post(X, y_data, N)
ula_sampler_jax_kernel(key, grad_log_post, num_samples,
                                              dt=dt, x_0=theta_true, print_rate=print_rate)
#ula_sampler_jax_kernel(key, grad_log_post, num_samples, dt, x_0=theta_true, print_rate=500)

generating data, with N=1000 and dim=4
[-2.0566347  -0.02853895 -4.105336   -3.1483822 ]
(4,)
<class 'jaxlib.xla_extension.ArrayImpl'>
[-2.0566347  -0.02853895 -4.105336   -3.1483822 ]
xdata[[ 1.4415115e+00 -1.0026432e+00 -1.2540932e+00  1.7176519e-01]
 [ 8.9354467e-01  6.6053903e-01 -1.2657831e+00  4.3967950e-01]
 [-8.6018312e-01 -6.4259398e-01 -1.0805608e+00  1.3151470e+00]
 ...
 [-7.3177606e-01 -1.0262679e+00 -3.0052990e-01 -1.4075083e+00]
 [-3.6647838e-01  2.7720016e-01  1.4874473e+00 -1.5547055e-01]
 [ 3.3587062e-01  1.6510191e-03  4.6714416e-01  1.9931759e+00]]
ydata[1 1 1 0 1 1 0 0 0 0 0 1 0 1 1 1 1 1 0 0 0 1 1 1 0 0 0 1 0 1 1 1 1 1 0 1 1
 0 1 0 0 0 0 0 1 1 1 1 1 0 1 0 1 1 0 1 1 1 1 0 0 1 1 0 0 1 1 1 0 0 0 1 1 0
 1 0 1 1 0 1 1 0 0 1 1 0 0 1 1 1 0 0 0 0 0 1 0 1 0 0 0 1 1 0 0 1 1 0 1 1 1
 1 0 0 1 1 1 0 1 0 1 1 1 1 0 1 1 0 0 0 0 0 0 0 0 1 0 0 0 1 0 1 1 0 0 0 1 1
 0 0 0 0 1 0 0 1 1 0 0 1 1 0 0 1 1 0 0 0 0 0 1 0 0 1 0 1 0 1 1 0 1 0 0 0 1
 0 1 0 0 0 1 1 1 0 0 0 1 0 0 1 0 0 1 1 0 0 0 1

array([[-1.93164706, -0.10784662, -4.15332365, -3.08923864],
       [-1.96454895, -0.06535793, -3.99635649, -3.15591908],
       [-2.18910909, -0.25673312, -4.04947758, -3.17306566],
       [-2.1026814 , -0.12165285, -4.07850266, -3.16832447],
       [-2.16543818, -0.16428615, -3.91803837, -3.1615386 ],
       [-2.22167659, -0.2581431 , -4.05087996, -3.10364795],
       [-2.1895051 , -0.30929056, -3.99353409, -3.21255803],
       [-2.34161305, -0.17326409, -4.21587658, -3.08674097],
       [-2.2121079 , -0.35579789, -3.9937129 , -3.1285429 ],
       [-2.15039921, -0.49635336, -4.12209988, -3.08596396],
       [-2.04387355, -0.3115744 , -4.14943838, -3.16309834],
       [-1.96261418, -0.1321487 , -4.23607206, -3.25259972],
       [-2.08475423, -0.01639844, -4.38709402, -3.14765263],
       [-2.16799903, -0.10149371, -4.27598381, -3.13940239],
       [-1.98980629, -0.17154115, -4.18998003, -3.23608041],
       [-1.79847801, -0.14314932, -4.23999691, -3.22668409],
       [-1.83534598, -0.

In [7]:
num_samples = 20000
print_rate = print_rate/2
dim = 5

key = random.PRNGKey(0)
print_rate = num_samples/2
N = 1000
minibatch_size = int(N*0.1)
dt = 5e-3

theta_true, X, y_data = gen_data(key, dim, N)
grad_log_post = build_batch_grad_log_post(X, y_data, N)
sgld_sampler_jax_kernel(key, grad_log_post, num_samples,
                                              dt=dt, x_0=theta_true, X=X, y_data=y_data,
                                              minibatch_size=minibatch_size, print_rate=print_rate)

generating data, with N=1000 and dim=5
[ 6.7104073 -1.4863473 -1.6330161 -1.152002   4.4074416]
(5,)
<class 'jaxlib.xla_extension.ArrayImpl'>
Python loop with Jax kernel
Iteration 0/20000
Iteration 10000/20000


array([[ 6.88011503, -1.27444351, -1.74291885, -1.22155797,  4.25603485],
       [ 6.83632898, -1.20264292, -1.55853927, -1.26840484,  4.22265434],
       [ 6.87512732, -1.38083804, -1.57559919, -1.20404816,  4.26270103],
       ...,
       [ 7.74614763, -1.17637455, -1.82227933, -1.49867773,  5.1773901 ],
       [ 7.73753977, -1.28514969, -1.74067044, -1.56225634,  5.35626411],
       [ 7.69129801, -1.26810491, -1.71750188, -1.8206985 ,  5.16555595]])