In [8]:
from typing import Callable

import jax.numpy as jnp
import jax.scipy as jsp
from jax import Array, random, vmap
from jax._src.prng import PRNGKeyArray
from jax.numpy.linalg import norm
from jax.typing import ArrayLike

In [9]:
def sample_ellip_prior(rng_key: PRNGKeyArray, sigma: float, n: int = 1):
    """Draw n isotropic ellipticities as (magnitude, angle) pairs.

    Magnitudes come from `sample_mag_ellip_prior` (Gary's prior, width `sigma`);
    angles are uniform on [0, pi). Independent sub-keys are used for each draw.

    Returns the two draws stacked along axis 1, i.e. one (magnitude, angle)
    row per sample (assuming `sample_mag_ellip_prior` returns shape (n,)).
    """
    mag_key, angle_key = random.split(rng_key, 2)
    magnitudes = sample_mag_ellip_prior(mag_key, sigma=sigma, n=n)
    angles = random.uniform(angle_key, shape=(n,), minval=0, maxval=jnp.pi)
    return jnp.stack((magnitudes, angles), axis=1)

In [10]:
def gete1e2_from_ebeta(e_mag:float, beta:float):
    """Convert an ellipticity given as (magnitude, angle) into components (e1, e2).

    The angle enters doubled: e1 = e_mag * cos(2 * beta), e2 = e_mag * sin(2 * beta).
    """
    phase = 2 * beta
    return e_mag * jnp.cos(phase), e_mag * jnp.sin(phase)

def get_ebeta_from_e1e2(e1:float, e2:float):
    """Convert ellipticity components (e1, e2) into (magnitude, angle).

    Inverse of the mapping e1 = e_mag * cos(2 * beta), e2 = e_mag * sin(2 * beta).

    Returns:
        (e_mag, beta) with e_mag = sqrt(e1**2 + e2**2) and
        beta = 0.5 * arctan2(e2, e1), so beta lies in [-pi/2, pi/2].
    """
    # arctan2 takes the two components separately (y, x); passing the ratio
    # e2 / e1 as a single argument is an error and would also lose the quadrant
    # and divide by zero when e1 == 0.
    beta = jnp.arctan2(e2, e1) * 0.5
    e_mag = jnp.sqrt(e1**2 + e2**2)
    return e_mag, beta


def scalar_shear_transformation(e: Array, g: Array):
    """Shear a single ellipticity `e` by a fixed shear `g` (scalar version).

    Both inputs are length-2 arrays holding (magnitude, angle). In complex
    notation the result is e' = (e + g) / (1 + conj(g) * e), which is
    equation 3.4b in Seitz & Schneider (1997). The sheared ellipticity is
    returned as a length-2 array (magnitude, angle).

    NOTE: This function is meant to be vmapped later.
    """
    assert e.shape == (2,) and g.shape == (2,)

    # (magnitude, angle) -> components
    e1, e2 = gete1e2_from_ebeta(e[0], e[1])
    g1, g2 = gete1e2_from_ebeta(g[0], g[1])

    # complex representation of ellipticity and shear
    ellip = e1 + e2 * 1j
    shear = g1 + g2 * 1j

    numerator = ellip + shear
    denominator = 1 + jnp.conj(shear) * ellip
    sheared = numerator / denominator

    # components -> (magnitude, angle)
    sheared_mag, sheared_beta = get_ebeta_from_e1e2(jnp.real(sheared), jnp.imag(sheared))
    return jnp.array([sheared_mag, sheared_beta])

def scalar_inv_shear_transformation(e: Array, g: Array):
    """Undo a fixed shear `g` on a single ellipticity `e` (scalar version).

    Both inputs are length-2 arrays holding (magnitude, angle). In complex
    notation the result is e' = (e - g) / (1 - conj(g) * e), the algebraic
    inverse of equation 3.4b in Seitz & Schneider (1997). The result is
    returned as a length-2 array (magnitude, angle).

    NOTE: This function is meant to be vmapped later.
    """
    assert e.shape == (2,) and g.shape == (2,)

    # (magnitude, angle) -> components
    e1, e2 = gete1e2_from_ebeta(e[0], e[1])
    g1, g2 = gete1e2_from_ebeta(g[0], g[1])

    # complex representation of ellipticity and shear
    ellip = e1 + e2 * 1j
    shear = g1 + g2 * 1j

    numerator = ellip - shear
    denominator = 1 - jnp.conj(shear) * ellip
    unsheared = numerator / denominator

    # components -> (magnitude, angle)
    unsheared_mag, unsheared_beta = get_ebeta_from_e1e2(
        jnp.real(unsheared), jnp.imag(unsheared)
    )
    return jnp.array([unsheared_mag, unsheared_beta])

# batched: map over the leading axis of the ellipticities, one shared shear
shear_transformation = vmap(scalar_shear_transformation, in_axes=(0, None))
inv_shear_transformation = vmap(scalar_inv_shear_transformation, in_axes=(0, None))


# useful for jacobian later: each helper exposes one scalar output of the
# inverse transformation
def inv_shear_func1(e, g):
    """Return element 0 (the magnitude) of the inverse shear transformation."""
    return scalar_inv_shear_transformation(e, g)[0]


def inv_shear_func2(e, g):
    """Return element 1 (the angle) of the inverse shear transformation."""
    return scalar_inv_shear_transformation(e, g)[1]

In [None]:
# first set of likelihoods:
def logtarget(
    e_sheared: Array,
    *,
    data: Array,  # renamed from `e_obs` for compatibility with `do_inference_nuts`
    sigma_m: float,
    interim_prior: Callable,
):
    """Log of the unnormalised target density for one sheared ellipticity.

    Requires `Callable` (typing) and `jsp` (jax.scipy) to be imported at the
    top of the notebook.

    Args:
        e_sheared: length-2 array; element 0 is passed to `interim_prior`.
        data: observed ellipticity, length-2 array with the same layout.
        sigma_m: standard deviation of the Gaussian measurement noise,
            applied to both elements.
        interim_prior: callable evaluated on `e_sheared[0]`; its log is taken,
            so it must return a density value (not a log-density).

    Returns:
        log(interim_prior(e_sheared[0])) plus the summed Gaussian log-pdf of
        `data` centred on `e_sheared`.
    """
    e_obs = data
    assert e_sheared.shape == (2,) and e_obs.shape == (2,)

    # ignore angle prior assumed uniform
    # prior enforces magnitude < 1.0 for posterior samples
    log_prior = jnp.log(interim_prior(e_sheared[0]))
    log_likelihood = jnp.sum(
        jsp.stats.norm.logpdf(e_obs, loc=e_sheared, scale=sigma_m)
    )
    return log_prior + log_likelihood