In [6]:
import jax
import jax.numpy as jnp
from jax import lax
import time
from jax import random
import time

In [8]:
def prior(N, dim, key):
    return random.uniform(key, shape=(N,dim), minval=-2, maxval=2, dtype=jnp.float32)

def simulator(theta, x, key):
    # Generate random noise sigma from a normal distribution
    sigma = random.normal(key, dtype=jnp.float32)*0.01
    # Compute the simulation
    sim = jnp.dot(theta, x) + sigma
    return sim

def inference(y, x, theta, e, key):
    accepted_samples = []
    for i in range(len(theta)):
        sim = simulator(theta[i, :], x, key)
        dis = jnp.linalg.norm(sim-y)
        #print("The distance is: ", d)
        if dis < e:
            accepted_samples.append(theta[i, :])
    return jnp.array(accepted_samples)

# Generate key
seed=123
key = random.PRNGKey(seed)

# Define samples and dim
N, dim = 100000, 3

# Observation
y_obs = jnp.array(0.3)
x = random.uniform(key, shape=(dim,), minval=-0.5, maxval=0.5)
print(x)
e = 0.01

# STart time
start_time = time.time()

theta = prior(N, dim, key)
#print("Generate theta: \n", theta)

sample_pos = inference(y_obs, x, theta, e, key)

end_time = time.time()
elapsed_time = end_time - start_time

print("Accepted samples shape: ", sample_pos.shape)
print(f"Inference function took {elapsed_time:.2f} seconds to run")

[-0.05844688  0.20968866  0.45037472]
Accepted samples shape:  (1094, 3)
Inference function took 27.12 seconds to run


In [33]:
sample_pos

Array([[ 1.2974119, -1.4024577,  1.5159726],
       [-1.2172799, -0.4043584,  0.7205825]], dtype=float32)

In [4]:
import jax
import jax.numpy as jnp
from jax import random
import time

# Prior distribution
def prior(N, dim, key):
    return random.uniform(key, shape=(N,dim), minval=-2, maxval=2, dtype=jnp.float32)

# Simulator function (adding small noise)
def simulator(theta, x, key):
    sigma = random.normal(key, dtype=jnp.float32) * 0.01
    return jnp.dot(theta, x) + sigma

# Inference function
def inference(y, x, theta, e, key):
    accepted_samples=[]

    # Iterate over all theta samples
    for i in range(len(theta)):
        # Split the key each time to generate independent random numbers
        key, subkey = random.split(key)
        # Simulate the data for the current theta
        sim = simulator(theta[i], x, subkey)
        # Compute the distance between the simulation and the observed data
        distance = jnp.linalg.norm(sim-y)
        # Check if the distance is smaller than threshold e and accept sample
        if distance<e:
            accepted_samples.append(theta[i])

    return jnp.array(accepted_samples)

# Main execution
seed=123
key = random.PRNGKey(seed)

# Define samples and dimension
N, dim = 100000, 3

# Observed value
y_obs = jnp.array(0.3)
x = random.uniform(key, shape=(dim,), minval=-0.5, maxval=0.5)
e = 0.01

# Start the timer
start_time = time.time()

# Generate theta using prior
key, subkey = random.split(key)
theta = prior(N, dim, subkey)

# Run inference
key, subkey = random.split(key)
sample_pos = inference(y_obs, x, theta, e, subkey)

end_time = time.time()
elapsed_time = end_time-start_time

# Print results
print("Accepted samples shape: ", sample_pos.shape)
print(f"Inference function took {elapsed_time:.2f} seconds to run.")

Accepted samples shape:  (1084, 3)
Inference function took 35.78 seconds to run.


### Vectorization for speed

In [5]:
import jax
import jax.numpy as jnp
from jax import random
import time

# Prior distribution
def prior(N, dim, key):
    return random.uniform(key, shape=(N, dim), minval=-2, maxval=2, dtype=jnp.float32)

# Simulator function (adding small noise)
def simulator(theta, x, key):
    sigma = random.normal(key, dtype=jnp.float32) * 0.01  # small noise
    return jnp.dot(theta, x) + sigma

# Vectorized Inference function
def inference(y, x, theta, e, key):
    # Generate independent subkeys for each simulation
    keys = random.split(key, theta.shape[0])

    # Simulate all theta values simultaneously using vmap
    sims = jax.vmap(lambda th, key: simulator(th, x, key))(theta, keys)

    # Calculate the distances for all samples at once
    distances = jnp.abs(sims - y)

    # Mask samples where distance < e
    mask = distances < e

    # Apply mask to theta to get accepted samples
    accepted_samples = theta[mask]

    return accepted_samples, sims, distances, mask

# Main execution
seed = 123
key = random.PRNGKey(seed)

# Define samples and dimension
N, dim = 100000, 3

# Observed value
y_obs = jnp.array(0.3)
x = random.uniform(key, shape=(dim,), minval=-0.5, maxval=0.5)
e = 0.01

# Start the time
start_time = time.time()

# Generate theta using prior
key, subkey = random.split(key)
theta = prior(N, dim, subkey)

# Run inference
key, subkey = random.split(key)
sample_pos, sims, distances, mask = inference(y_obs, x, theta, e, subkey)

# Calculeta time
end_time = time.time()
elapsed_time = end_time - start_time

# Print results
print("Accepted samples shape: ", sample_pos.shape)
print(f"Inference function took {elapsed_time:.2f} seconds to run.")


Accepted samples shape:  (1035, 3)
Inference function took 0.50 seconds to run.
