# ABC: Approximate Bayesian Computation (rejection sampling)

## NumPy

In [28]:
import numpy as np

In [33]:
def prior(N, dim):
    """Draw ``N`` candidate parameter vectors of length ``dim`` from the prior.

    Each coordinate is sampled uniformly from [-2, 2) using NumPy's global
    (legacy) random state, so the result depends on any earlier
    ``np.random.seed`` call. The draws are cast to float32; float32 rounding
    may map values extremely close to 2 onto 2.0.
    """
    draws = np.random.uniform(low=-2, high=2, size=(N, dim))
    return draws.astype(np.float32)


def simulator(theta, x):
    """Simulate one noisy observation for the parameter vector ``theta``.

    Returns ``np.dot(theta, x)`` plus a single Gaussian noise draw
    (mean 0, standard deviation 0.1) taken from NumPy's global random state.
    """
    noise = np.random.normal(0, 0.1)
    return np.dot(theta, x) + noise

def inference(y, x, theta, e, simulator_fn=None):
    """Run ABC rejection sampling over a batch of prior draws.

    Each row of ``theta`` is passed through the simulator once, in order.
    A row is accepted when the Euclidean distance between its simulated
    output and the observation ``y`` is strictly below ``e``.

    Args:
        y: Observed data, compared via ``np.linalg.norm(sim - y)``.
        x: Covariates forwarded unchanged to the simulator.
        theta: 2-D array of candidate parameters, one candidate per row.
        e: Acceptance threshold on the distance.
        simulator_fn: Optional callable ``(theta_row, x) -> sim``. Defaults
            to the module-level ``simulator``.

    Returns:
        Array of shape ``(k, theta.shape[1])`` with the accepted rows in
        their original order and dtype. ``k`` may be 0.
    """
    sim_fn = simulator if simulator_fn is None else simulator_fn
    theta = np.asarray(theta)
    # dtype=bool keeps the mask a valid boolean index even when theta has no
    # rows; an empty float array cannot be used as an index.
    accept = np.array(
        [np.linalg.norm(sim_fn(row, x) - y) < e for row in theta],
        dtype=bool,
    )
    # Boolean indexing preserves the 2-D shape, so zero acceptances yield
    # (0, dim) instead of the shape-(0,) array that np.array([]) produces.
    return theta[accept]


# Problem size: number of prior draws and parameter dimension
# (dimensions tried: 2, 5, 10).
N = 100000
dim = 10

# Observed data and the ABC acceptance threshold.
y_obs = np.array(0.3)
e = 0.01

# Fixed seed so the covariates x are identical on every run.
np.random.seed(42)
x = np.random.uniform(low=-0.5, high=0.5, size=(dim,))

# Reseed with seed=None so the prior draws and simulator noise differ
# between runs.
np.random.seed(None)
theta_pr = prior(N, dim)

# Rejection step: keep the prior draws whose simulation lands within e of y_obs.
samples_pos = inference(y_obs, x, theta_pr, e)

In [34]:
# Accepted-sample counts seen in earlier runs (they vary from run to run):
# dim=2 -> 1167, dim=5 -> 981, dim=10 -> 669.
np.shape(samples_pos)

(703, 10)

## JAX

In [7]:
import jax
import jax.numpy as jnp
from jax import lax
import time

In [27]:
def prior(N, dim, key=None):
    """Draw ``N`` candidate parameter vectors of length ``dim`` from the prior.

    Each coordinate is sampled uniformly between -2 and 2 as float32.

    Args:
        N: Number of draws.
        dim: Length of each parameter vector.
        key: JAX PRNG key. If None, one is seeded from the current time in
            microseconds, so the result is not reproducible.

    Returns:
        Tuple ``(samples, next_key)``: samples of shape ``(N, dim)`` and the
        key to use for subsequent random draws.
    """
    base_key = jax.random.PRNGKey(int(time.time() * 1e6)) if key is None else key
    next_key, draw_key = jax.random.split(base_key)
    samples = jax.random.uniform(
        draw_key, shape=(N, dim), dtype=jnp.float32, minval=-2, maxval=2
    )
    return samples, next_key

def simulator(theta, x, rng_key=None):
    """Simulate one noisy observation for ``theta`` (JAX version).

    Computes ``jnp.dot(theta, x)`` plus a float32 Gaussian noise draw scaled
    by 0.1. The noise is drawn from a subkey split off ``rng_key``.

    Args:
        theta: Parameter vector.
        x: Covariates.
        rng_key: JAX PRNG key. If None, one is seeded from the current time
            in microseconds. NOTE(review): two key-less calls within the same
            microsecond would reuse the same key.

    Returns:
        Tuple ``(sim, next_key)``; pass ``next_key`` to the next call.
    """
    key = jax.random.PRNGKey(int(time.time() * 1e6)) if rng_key is None else rng_key
    next_key, noise_key = jax.random.split(key)
    noise = 0.1 * jax.random.normal(noise_key, shape=(), dtype=jnp.float32)
    return jnp.dot(theta, x) + noise, next_key

def inference(y, x, theta, e, rng_key=None, simulator_fn=None):
    """Run ABC rejection sampling over a batch of prior draws (JAX version).

    Rows of ``theta`` are simulated in order while a single PRNG key is
    threaded through the simulator: each call returns the key for the next
    row. This reproduces the key sequence of a plain sequential loop. The
    loop runs inside ``lax.scan``, so it is one compiled loop rather than
    one eager dispatch per row.

    Args:
        y: Observation, compared to each simulation via
            ``jnp.linalg.norm(sim - y)``.
        x: Covariates forwarded to the simulator.
        theta: 2-D array of candidate parameters, one per row.
        e: Acceptance threshold. A row is kept when its distance is strictly
            below ``e``.
        rng_key: Initial PRNG key. If None, one is seeded from the current
            time in microseconds (not reproducible).
        simulator_fn: Optional callable ``(theta_row, x, key) -> (sim, next_key)``
            that JAX can trace. Defaults to the module-level ``simulator``.

    Returns:
        ``jnp`` array of shape ``(k, theta.shape[1])`` with the accepted rows
        in their original order. ``k`` may be 0.
    """
    sim_fn = simulator if simulator_fn is None else simulator_fn
    theta = jnp.asarray(theta)
    if theta.shape[0] == 0:
        # Nothing to simulate; keep the (0, dim) shape.
        return theta
    if rng_key is None:
        # Same clock-based fallback the simulator applies when given no key.
        rng_key = jax.random.PRNGKey(int(time.time() * 1e6))

    def step(key, row):
        # Carry the key forward exactly as a sequential loop would.
        sim, key = sim_fn(row, x, key)
        return key, jnp.linalg.norm(sim - y)

    _, dists = lax.scan(step, rng_key, theta)
    # Boolean masking on concrete arrays (outside jit) preserves the 2-D
    # shape, so zero acceptances give (0, dim) rather than the shape-(0,)
    # array that jnp.array([]) would produce.
    return theta[dists < e]

# Define the dimensions and number of samples
N = 100000
dim = 10  # Accepted counts seen in earlier runs: dim=2 -> 1224, dim=5 -> 1051, dim=10 -> 970.

# Observation
y_obs = jnp.array(0.3)
# Fixed key so the covariates x are the same on every run.
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(dim,), minval=-0.5, maxval=0.5)
print(x)
e = 0.01  # acceptance threshold

# Sampling thetas. prior() seeds from the clock when no key is given; pass
# key=jax.random.PRNGKey(<seed>) instead for a reproducible run.
theta_pr, global_key = prior(N, dim)

# Thread the key returned by prior() into inference so the simulator noise
# continues the same PRNG stream instead of reseeding from the clock.
samples_pos = inference(y_obs, x, theta_pr, e, rng_key=global_key)
print("Accepted samples shape:", samples_pos.shape)


[-0.14509487  0.10419905 -0.07241571 -0.26938403 -0.17014146 -0.06046343
 -0.24900234 -0.22269428  0.26782072  0.21474564]
Accepted samples shape: (970, 10)
