Here are the implementations of multinomial and systematic resampling in NumPy, JAX, and jit-compiled JAX.

In [None]:
import numpy as np

def multinomial_resampling(particles, weights):
    """
    Draw a new particle set by multinomial resampling (NumPy global RNG).

    Each output row is picked independently, with replacement, with
    probability proportional to the weight of the corresponding input row.

    :param particles: Array of particles with shape (num_particles, particle_dimension).
    :param weights: Array of weights with shape (num_particles,); it is
        divided by its sum before sampling, so it need not be normalized.
    :return: Array of resampled particles with the same number of rows as ``particles``.
    """
    count = particles.shape[0]

    # Turn the raw weights into a probability vector.
    probabilities = weights / np.sum(weights)

    # One categorical draw per output slot; indexing gathers the chosen rows.
    draws = np.random.choice(count, size=count, p=probabilities)
    return particles[draws]

import jax.numpy as jnp
from jax import random, jit

def multinomial_resampling_jax(particles, weights, key):
    """
    Resample particles based on their weights using multinomial resampling (JAX).

    :param particles: An array of particles whose leading axis has length num_particles.
    :param weights: An array of non-negative weights with shape (num_particles,).
        The weights do not have to be normalized.
    :param key: A JAX PRNG key; a subkey derived from it drives the sampling.
    :return: The resampled particles, with the same shape as ``particles``.
    """
    num_particles = particles.shape[0]

    # random.categorical expects *logits* (unnormalized log-probabilities),
    # not probabilities. Passing probabilities directly would sample from
    # softmax(weights), which is nearly uniform. Taking the log fixes that;
    # zero weights map to -inf and are never selected. Explicit normalization
    # is unnecessary because the categorical distribution is invariant to a
    # constant shift of the logits.
    logits = jnp.log(weights)

    # Resample particle indices from the categorical distribution.
    _, subkey = random.split(key)
    resampled_indices = random.categorical(subkey, logits, shape=(num_particles,))

    return particles[resampled_indices]

@jit
def multinomial_resampling_jit(particles, weights, key):
    """
    Jit-compiled multinomial resampling of particles according to their weights.

    :param particles: An array of particles whose leading axis has length num_particles.
    :param weights: An array of non-negative weights with shape (num_particles,).
        The weights do not have to be normalized.
    :param key: A JAX PRNG key; a subkey derived from it drives the sampling.
    :return: The resampled particles, with the same shape as ``particles``.
    """
    # Shapes are static under jit, so this is a plain Python int.
    num_particles = particles.shape[0]

    # Stay in jax.numpy inside the traced function (no NumPy calls on tracers),
    # and hand random.categorical *logits* rather than probabilities: it
    # samples from softmax(logits). Zero weights become -inf and are never
    # selected; normalization is unnecessary because a constant shift of the
    # logits does not change the distribution.
    logits = jnp.log(weights)

    # Resample particle indices from the categorical distribution.
    _, subkey = random.split(key)
    resampled_indices = random.categorical(subkey, logits, shape=(num_particles,))

    return particles[resampled_indices]

def systematic_resampling(particles, weights):
    """
    Resample particles based on their weights using systematic resampling.

    A single uniform offset is drawn in [0, 1/num_particles) and
    num_particles evenly spaced positions are laid over the normalized
    cumulative weights; each position selects the first particle whose
    cumulative weight reaches it.

    :param particles: An array of particles whose leading axis has length num_particles.
    :param weights: An array of non-negative weights with shape (num_particles,).
        The weights do not have to be normalized.
    :return: The resampled particles.
    :raises ValueError: If the weights do not sum to a positive, finite value.
    """
    num_particles = particles.shape[0]
    if num_particles == 0:
        # Nothing to resample; avoid dividing by zero below.
        return particles[:0].copy()

    cumulative_weights = np.cumsum(np.asarray(weights, dtype=float))
    total = cumulative_weights[-1]
    if not np.isfinite(total) or total <= 0:
        raise ValueError("weights must sum to a positive, finite value")
    # Normalize by the last cumulative value so the final entry is exactly 1.0
    # and every sampling position (all < 1) is covered.
    cumulative_weights = cumulative_weights / total

    step_size = 1.0 / num_particles
    r = np.random.uniform(0, step_size)
    positions = r + step_size * np.arange(num_particles)

    # First index whose cumulative weight is >= the position, found for all
    # positions at once; the clip guards against floating-point round-off
    # pushing the last position past the end.
    indices = np.searchsorted(cumulative_weights, positions, side="left")
    indices = np.minimum(indices, num_particles - 1)

    return particles[indices]

def systematic_resampling_jax(weights, key):
    """
    Compute systematic-resampling indices with JAX.

    Unlike the NumPy variant, this function returns the selected particle
    indices rather than the gathered particles.

    :param weights: An array of non-negative weights with shape (n,). The
        weights do not have to be normalized; positions are scaled by their sum.
    :param key: A JAX PRNG key used to draw the single uniform offset.
    :return: An integer array of shape (n,) holding the resampled indices.
    """
    n = len(weights)
    if n == 0:
        # No particles: return an empty index array instead of indexing [-1].
        return jnp.zeros((0,), dtype=jnp.int32)

    cum_weights = jnp.cumsum(weights)
    step = cum_weights[-1] / n

    # Systematic resampling shares ONE uniform offset across all n strata
    # (independent offsets per stratum would be stratified resampling).
    offset = random.uniform(key, ())
    positions = (jnp.arange(n) + offset) * step

    # For each position, the first index whose cumulative weight reaches it.
    # searchsorted can advance past several zero-weight particles at once.
    indices = jnp.searchsorted(cum_weights, positions, side="left")

    # Guard against round-off placing the last position past the final bin.
    return jnp.minimum(indices, n - 1)

@jit
def systematic_resampling_jit(weights, key):
    """
    Jit-compiled computation of systematic-resampling indices.

    This function returns the selected particle indices rather than the
    gathered particles.

    :param weights: An array of non-negative weights with shape (n,). The
        weights do not have to be normalized; positions are scaled by their sum.
    :param key: A JAX PRNG key used to draw the single uniform offset.
    :return: An integer array of shape (n,) holding the resampled indices.
    """
    # Array lengths are static under jit, so n is a plain Python int and the
    # empty-input check below is resolved at trace time.
    n = len(weights)
    if n == 0:
        return jnp.zeros((0,), dtype=jnp.int32)

    cum_weights = jnp.cumsum(weights)
    step = cum_weights[-1] / n

    # Systematic resampling shares ONE uniform offset across all n strata
    # (independent offsets per stratum would be stratified resampling).
    offset = random.uniform(key, ())
    positions = (jnp.arange(n) + offset) * step

    # For each position, the first index whose cumulative weight reaches it.
    # searchsorted can advance past several zero-weight particles at once.
    indices = jnp.searchsorted(cum_weights, positions, side="left")

    # Guard against round-off placing the last position past the final bin.
    return jnp.minimum(indices, n - 1)



Now set up the test data and the timing containers used to benchmark the functions.

In [None]:
import time

# Size of the synthetic particle set.
num_particles, particle_dimension = 100000, 10

# Random particles and raw weights drawn from NumPy's global RNG
# (nothing here normalizes the weights).
particles = np.random.randn(num_particles, particle_dimension)
weights = np.random.rand(num_particles)

# Seeded JAX PRNG key for the JAX-based resamplers.
seed = 42
key = random.PRNGKey(seed)

# Particle counts to iterate over: powers of ten from 1 to 10,000,000.
particle_lst = [10 ** exponent for exponent in range(8)]

# Result containers, one independent list each, for the multinomial variants...
multi_np_times, multi_jax__times, multi_jax_jit = [], [], []

# ...and for the systematic variants.
sys_np_times, sys_jax__times, sys_jax_jit = [], [], []


# Placeholder loop: the body does nothing yet.
for n_part in particle_lst:
    pass