In [46]:
import jax
from functools import partial
import numpy as np
import jax.numpy as jnp

In [62]:
class Model:
    """Configuration holder for a particle set over a bounded parameter space.

    Attributes:
        no_parameters: Number of parameters (axes) each particle has.
        no_particles: Number of particles in the set.
        space_boundaries: Per-parameter bounds, indexed by parameter along
            the leading axis (the notebook uses one ``[low, high]`` row per
            parameter).
        measurement_repetitions: Stored on the instance; not used by the
            methods defined in this class.
    """

    def __init__(
        self,
        no_parameters: int,
        no_particles: int,
        # np.ndarray is the array *type*; np.array is only a factory function.
        space_boundaries: np.ndarray,
        measurement_repetitions: int = 10,
    ) -> None:
        """Store the configuration values unchanged on the instance."""
        self.no_parameters = no_parameters
        self.no_particles = no_particles
        self.space_boundaries = space_boundaries
        self.measurement_repetitions = measurement_repetitions

    def initialize_particle_locations(self, key):
        """Sample initial particle locations within ``space_boundaries``.

        Args:
            key: JAX PRNG key used for sampling.

        Returns:
            Whatever ``_initialize_particle_locations`` returns: a
            ``(key, locations)`` tuple.
        """
        return _initialize_particle_locations(
            key, self.no_parameters, self.space_boundaries, self.no_particles
        )

    def initialize_weights(self):
        """Return uniform weights: ``no_particles`` entries of ``1 / no_particles``."""
        return jnp.ones(self.no_particles) / self.no_particles



def _initialize_particle_locations(
    key, no_of_parameters, boundaries, no_of_particles
):
    """Draw uniformly distributed particle locations, one column per parameter.

    Args:
        key: JAX PRNG key. It is split here and must not be reused by the
            caller afterwards; use the returned key instead.
        no_of_parameters: Number of parameters; one subkey is derived for
            each of them.
        boundaries: Array-like whose leading axis indexes parameters. Each
            row is passed to ``populate_one_axis`` as the bounds of that
            parameter. Nested lists are accepted and converted to an array.
        no_of_particles: Number of samples drawn per parameter.

    Returns:
        A ``(new_key, locations)`` tuple. ``new_key`` is a fresh PRNG key
        for subsequent random draws; ``locations`` has shape
        ``(no_of_particles, no_of_parameters)``.
    """
    # One key to hand back to the caller plus one sampling key per parameter.
    keys = jax.random.split(key, no_of_parameters + 1)
    # The key handed back must be disjoint from the sampling keys. Returning
    # keys[1] (which is also the first sampling key) would make later draws
    # from the returned key correlated with the first parameter's samples.
    next_key, axis_keys = keys[0], keys[1:]
    # vmap maps over the leading axis of both the keys and the boundaries;
    # the particle count is shared by every axis (in_axes=None).
    per_axis = jax.vmap(populate_one_axis, in_axes=(0, 0, None))(
        axis_keys, jnp.asarray(boundaries), no_of_particles
    )
    # per_axis is (no_of_parameters, no_of_particles); transpose so that each
    # row is one particle.
    return next_key, per_axis.T


def populate_one_axis(key, bnds, no_particles):
    """Sample ``no_particles`` uniform values between the extremes of ``bnds``.

    Args:
        key: JAX PRNG key used for the draw.
        bnds: Array of bounds; only its smallest and largest entries are
            used, so their order does not matter.
        no_particles: Number of values to sample.

    Returns:
        1-D array of length ``no_particles``.
    """
    lower = jnp.min(bnds)
    upper = jnp.max(bnds)
    return jax.random.uniform(
        key, shape=(no_particles,), minval=lower, maxval=upper
    )

In [63]:
# Eight parameters, each bounded to [-1, 1]; one [low, high] row per parameter.
model = Model(
    no_parameters=8,
    no_particles=100,
    space_boundaries=np.array([[-1, 1]] * 8),
)

In [64]:
# Fixed seed so the sampled particle locations are reproducible.
key = jax.random.PRNGKey(seed=1)

In [65]:
# Uniform starting weights; this does not consume any randomness.
weights = model.initialize_weights()
# Sampling consumes `key`; rebind it to the returned key for later draws.
key, locations = model.initialize_particle_locations(key)

(100,)