# ABC 

### ABC Algorithm

1. Draw θ from the prior
2. Generate the simulated data g(θ,x) = N(y, θ1x1 + θ2x2, σ^2)
3. Calculate the distance d(f(θ,x), y)
4. Accept the samples if d(f(θ,x)) < ε (threshold)

_The problem with the ABC algorithm is that a lot of samples get rejected. As the number of parameters increases, the number of accepted samples decreases. Therefore, it's necessary to reduce the number of rejected samples and make the algorithm more efficient and deterministic._

### Numpy

In [2]:
import numpy as np

In [8]:
def prior(N,dim):
    return np.random.uniform(-2,2, size=(N,dim)).astype(np.float32)


def simulator(theta, x):
    sigma = np.random.normal(0, 0.01)
    sim = np.dot(theta, x) + sigma
    return sim

def simulator2(theta, x):
    # Compute the mean of the normal distribution
    mean = np.dot(theta, x)
    # Assume fixed variance
    var = 0.01
    # Generate sample from the normal distribution with mean and variance
    sim = np.random.normal(mean, np.sqrt(var))
    return sim

def inference(y, x, theta, e):
    accepted_samples = []
    for i in range(len(theta)):
        sim = simulator2(theta[i, :], x)
        #print(sim)
        dis = np.linalg.norm(sim - y)
        #rint(dis)
        if dis < e:
            accepted_samples.append(theta[i,:])
    return np.array(accepted_samples)


# Define the dimensions and the number of samples
N = 100000
dim = 10 # 2, 5, 10

# Observation
y_obs = np.array(0.3)
# Use seed to have the same value for x in each run
np.random.seed(42)
x = np.random.uniform(-0.5, 0.5, size=(dim,))
#print(x)
e = 0.01 # threshold

# Sampling thetas
np.random.seed(None)
theta_pr =prior(N, dim)

# List of the accepted samples
samples_pos = inference(y_obs, x, theta_pr, e)

In [9]:
samples_pos.shape # simulator1: 2: 1167, 5: 981, 10: 669, simulator2: 2: 1097:931, 10:642

(642, 10)

## OMC

### OMC algorithm

1. Draw θ from the prior
2. Draw u (nuisance variables) from its distribution
3. Generate the simulated data ysim = f(θ x u)
4. Define the objective function d(f(θ,u), y))
5. Minimize the objective function to find the optimal u, u*
6. Find the new simulated data based on u*
7. Compute the jacobian matrix J $ \frac{\partial f(\theta, u)}{\partial \theta} $
8. Compute the weights from the formula $ w_i = \left(\det(J^\top J)\right)^{-\frac{1}{2}} $
9. Accept $ \theta_i $ as posterior sample with weight $ w_i $.

_The determinant of the Jacobian matrix represents the volume of the parallepepiped spanned by the columns of $ J_i $. This volume gives insights in how changes in $ \theta $ affect the summary statistics. Ill-conditioned matrices $  J_i^\top J_i $ produce very large weights for the corresponding $ \theta_i $, possibly completely overshadowing the remaining samples and creating an approximate posterior density that is spiked at a single location. These ill-conditioned matrices occur when a large parameter region around the optimum $ \theta_i $ produces data with small distances. In other words, the OMC failure occurs when a large parameter region around $ \theta_i $ is a solution to $ || f(\theta, u) - y|| < \epsilon $, meaning different $ \theta $ generate the same $ y $, which happens for example, when the likelihood function is (nearly) constant around $ \theta_i $._

In [46]:
import numpy as np
from scipy.optimize import minimize

def prior(N, dim):
    return np.random.uniform(-2, 2, size=(N, dim)).astype(np.float32)

def simulator(theta, u):
    # Assuming the simulator function is linear in theta and u introduces stochasticity
    return np.dot(theta, u)

def objective_function(u, theta, y):
    sim = simulator(theta, u)
    return np.linalg.norm(sim - y)

def compute_jacobian(theta, u):
    # Jacobian of the simulator with respect to theta
    return np.array([u] * len(theta))

def compute_weight(jacobian_matrix, regularization=1e-8):
    #JTJ = np.dot(jacobian_matrix.T, jacobian_matrix) + np.eye(jacobian_matrix.shape[1]) * regularization
    JTJ = np.dot(jacobian_matrix.T, jacobian_matrix)
    det = np.linalg.det(JTJ)
    return (1.0 / np.sqrt(det)) if det > 0 else 0

def inference(y, theta, u_dim):
    accepted_samples = []
    weights = []
    
    for theta_i in theta:
        #print(f'Theta_i is: {theta_i}')
        
        # Initial guess for u (can be sampled from a distribution)
        u_init = np.random.uniform(-1, 1, u_dim)
        #print(f'U_init is: {u_init}')
        result = minimize(objective_function, u_init, args=(theta_i, y), method='L-BFGS-B')
        u_star = result.x
        #print(f'u_star is {u_star}')
        
        sim = simulator(theta_i, u_star)
        jacobian_matrix = compute_jacobian(theta_i, u_star)
        weight = compute_weight(jacobian_matrix)
        
        accepted_samples.append(theta_i)
        weights.append(weight)
    
    return np.array(accepted_samples), np.array(weights)

# Define the dimensions and the number of samples
N = 1000
theta_dim, u_dim = 5, 5  # Dimension of theta, # Dimension of nuisance variables

# Observation
y_obs = np.full(theta_dim, 0.3)  # Assuming y has the same dimension as simulator output
np.random.seed(42)

# Sampling thetas
np.random.seed(None)
theta_pr = prior(N, theta_dim)

# List of the accepted samples
samples_pos, weights = inference(y_obs, theta_pr, u_dim)
print(f'Number of accepted samples: {len(samples_pos)}')


Number of accepted samples: 1000


In [44]:
weights

[0,
 0,
 3.072279379752269e+32,
 4.363429448805528e+32,
 0,
 0,
 0,
 0,
 4.387664344777639e+31,
 0,
 3.571678498935494e+34,
 0,
 0,
 2.550354252905469e+32,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1.2164115828141494e+32,
 2.3621135640150057e+32,
 0,
 0,
 0,
 4.3050287016928986e+33,
 0,
 4.800380200327931e+32,
 0,
 0,
 0,
 1.7819417562856782e+32,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1.2121030689215575e+32,
 0,
 0,
 0,
 7.719314313576575e+31,
 1.6669639303642515e+33,
 2.0856188027715676e+32,
 1.0621889239842815e+31,
 4.300936825390029e+31,
 0,
 8.590503505983072e+31,
 0,
 4.968385249056952e+31,
 0,
 1.6268943002407214e+32,
 3.059472711406161e+32,
 0,
 2.4736828607359856e+38,
 7.483116707475131e+31,
 0,
 0,
 7.493093445737249e+35,
 1.4050229311904254e+33,
 2.963404368729938e+32,
 0,
 0,
 0,
 0,
 0,
 3.708163761846783e+32,
 2.2177814075833348e+33,
 2.4745778870893816e+32,
 7.423269739238661e+30,
 0,
 1.780456620494702e+32,
 4.898426703418154e+30,
 0,
 6.324419170459073e+31,
 0,
 0,
 1.4971276483037814

In [48]:
# Create a mask for non-zero weights
non_zero_mask = weights > 0

# Apply the mask to filter samples and weights
filtered_samples = samples_pos[non_zero_mask]
filtered_weights = weights[non_zero_mask]

print(f'The number of the non-negative samples are: {len(filtered_samples)}')

The number of the non-negative samples are: 358


## JAX

In [7]:
import jax
import jax.numpy as jnp
from jax import lax
import time

In [27]:
def prior(N, dim, key=None):
    if key is None:
        key = jax.random.PRNGKey(int(time.time() * 1e6))  # Use current timestamp as seed
    key, subkey = jax.random.split(key)  # Split the key to get a new subkey
    return jax.random.uniform(subkey, shape=(N, dim), minval=-2, maxval=2, dtype=jnp.float32), key

def simulator(theta, x, rng_key=None):
    # Generate noise using JAX's random number generation
    if rng_key is None:
        rng_key = jax.random.PRNGKey(int(time.time()*1e6))
    rng_key, subkey = jax.random.split(rng_key)  # Split the key for noise generation
    sigma = jax.random.normal(subkey, shape=(), dtype=jnp.float32) * 0.1
    # Generate the simulated data 
    sim = jnp.dot(theta, x) + sigma
    #print(f'The simulated data are: ', sim)
    return sim, rng_key

def inference(y, x, theta, e, rng_key=None):
    accepted_samples = []
    for i in range(len(theta)):
        sim, rng_key = simulator(theta[i, :], x, rng_key)
        dis = jnp.linalg.norm(sim - y)
        #print(f'The distance is: ', dis)
        if dis < e:
            accepted_samples.append(theta[i, :])
    return jnp.array(accepted_samples)

# Define the dimensions and number of samples
N = 100000
dim = 10 # 2: 1224, 5: 1051, 10: 970

# Observation
y_obs = jnp.array(0.3)
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(dim,), minval=-0.5, maxval=0.5)
print(x)
e = 0.01

# Sampling thetas
#global_key = jax.random.PRNGKey(0)
theta_pr, global_key = prior(N, dim)
#print(f'Theta priors are: ', theta_pr)

# List of the accepted samples
samples_pos = inference(y_obs, x, theta_pr, e)
print("Accepted samples shape:", samples_pos.shape)


[-0.14509487  0.10419905 -0.07241571 -0.26938403 -0.17014146 -0.06046343
 -0.24900234 -0.22269428  0.26782072  0.21474564]
Accepted samples shape: (970, 10)
