In [1]:
from phasespace import GenParticle
from particle import Particle
import jax.numpy as np
import jax.random as rjax

In [2]:
# from phspdecay import generate

In [3]:
# ws, genpcls = generate('', 3)
# genpcls

In [4]:
from cluster import Cluster, momentum_from_cluster_jacobian



In [5]:
from helix import Helix

In [6]:
def random_cluster(rng, N):
    """Draw ``N`` random clusters with independently sampled fields.

    Args:
        rng: a ``jax.random`` PRNG key.
        N: number of clusters to draw.

    Returns:
        ``Cluster(energy, costh, phi)`` built from three arrays of shape ``(N,)``,
        sampled uniformly from ``[0, 3)``, ``[-1, 1)`` and ``[-pi, pi)``
        respectively.
    """
    # Split the key: reusing one key for all three draws yields the same
    # underlying uniform variates, which would make energy, costh and phi
    # perfectly correlated (each an affine function of one sample).
    key_energy, key_costh, key_phi = rjax.split(rng, 3)
    energy = rjax.uniform(key_energy, (N,), minval=0., maxval=3.)
    costh = rjax.uniform(key_costh, (N,), minval=-1., maxval=1.)
    phi = rjax.uniform(key_phi, (N,), minval=-np.pi, maxval=np.pi)
    return Cluster(energy, costh, phi)

In [7]:
# One fixed seed for the whole notebook so that its random draws are
# reproducible on every fresh run.
SEED = 0
rng = rjax.PRNGKey(SEED)

In [8]:
# Sample a batch of random clusters and evaluate the Jacobian returned by
# momentum_from_cluster_jacobian on it.
N = 100
clu = random_cluster(rng=rng, N=N)
jac = momentum_from_cluster_jacobian(clu)

In [9]:
# Shape of the px/energy entry (presumably d(px)/d(energy)) over the batch;
# the recorded output below is (N, 1).
jac.px.energy.shape

(100, 1)

In [10]:
def test_momentum_from_cluster_jacobian():
    """ """
    N = 100
    clu = random_cluster(rng, N)
    jac = momentum_from_cluster_jacobian(clu)

    assert jac.px.energy.shape == (N, 1)
    assert jac.py.energy.shape == (N, 1)
    assert jac.pz.energy.shape == (N, 1)
    assert jac.px.costh.shape == (N, 1)
    assert jac.py.costh.shape == (N, 1)
    assert jac.pz.costh.shape == (N, 1)
    assert jac.px.phi.shape == (N, 1)
    assert jac.py.phi.shape == (N, 1)
    assert jac.pz.phi.shape == (N, 1)

In [11]:
# Run the Jacobian shape check; it raises AssertionError on any mismatch.
test_momentum_from_cluster_jacobian()

In [12]:
from cluster import cartesian_to_cluster
from cartesian import Position, Momentum

In [13]:
# Build clusters from random Cartesian positions and momenta.
# Independent keys are required: drawing both (N, 3) arrays from the same key
# would make the positions and momenta identical.
N = 100
key_pos, key_mom = rjax.split(rng)
pos = Position.from_ndarray(rjax.uniform(key_pos, (N, 3)))
mom = Momentum.from_ndarray(rjax.uniform(key_mom, (N, 3)))
clu = cartesian_to_cluster(pos, mom)

In [14]:
# Shape of the cluster batch as a single array; the recorded output below is (N, 3).
clu.as_array.shape

(100, 3)

In [15]:
from cluster import sample_cluster_resolution, cluster_covariance
import jax

In [16]:
def test_cluster_covariance():
    """ """
    N = 100
    clu = random_cluster(rng, N)
    cov = cluster_covariance(clu)

    assert cov.shape == (N, 3, 3)

In [17]:
# Run the covariance shape check; it raises AssertionError on a mismatch.
test_cluster_covariance()

In [18]:
def test_sample_cluster_resolution():
    N = 100
    clu = random_cluster(rng, N)
    sclu, cov = sample_cluster_resolution(clu)

    assert cov.shape == (N, 3, 3)
    assert sclu.as_array.shape == (N, 3)

In [19]:
# Run the resolution-sampling shape check; it raises AssertionError on a mismatch.
test_sample_cluster_resolution()

In [20]:
cov = cluster_covariance(clu)


def mvn(cov):
    """Draw one zero-mean sample per covariance matrix in a ``(..., d, d)`` batch."""
    # multivariate_normal broadcasts over the leading batch axes of ``cov`` and
    # draws independent standard-normal noise for every batch entry from one
    # key. vmapping with a single shared key would instead give every cluster
    # the same underlying draw, only reshaped by its own covariance.
    return rjax.multivariate_normal(rng, np.zeros(cov.shape[-1]), cov)


dclu = mvn(cov)

In [21]:
# One sampled offset vector per cluster; the recorded output below is (100, 3).
dclu.shape

(100, 3)

In [22]:
# One 3x3 covariance matrix per cluster; the recorded output below is (100, 3, 3).
cov.shape

(100, 3, 3)