In [1]:
import os

In [2]:
os.environ["JAX_PLATFORMS"] = "cuda,cpu"
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [3]:
import jax

In [4]:
from functools import partial

import jax
import jax.numpy as jnp
import typer
from jax import random

from jax import jit 

from bpd import DATA_DIR
from bpd.io import load_dataset_jax, save_dataset
from bpd.jackknife import run_bootstrap_shear_pipeline
from bpd.pipelines import pipeline_shear_inference_simple, logtarget_shear

In [5]:
from bpd.chains import run_inference_nuts
from bpd.likelihood import shear_loglikelihood
from bpd.prior import (
    ellip_prior_e1e2,
    interim_gprops_logprior,
    true_all_params_skew_logprior,
    true_ellip_logprior,
)
from bpd.sample import sample_noisy_ellipticities_unclipped
from bpd.utils import uniform_logpdf

In [6]:
CPU = jax.devices('cpu')[0]
GPU = jax.devices('gpu')[0]

# Run

In [7]:
dirpath = DATA_DIR / "cache_chains" / "exp70_51"
samples_plus_fpath = dirpath / "interim_samples_511_plus.npz"
samples_minus_fpath = dirpath / "interim_samples_511_minus.npz"
assert samples_plus_fpath.exists() and samples_minus_fpath.exists()
# fpath = dirpath / f"g_samples_boots_{seed}.npz"

dsp = load_dataset_jax(samples_plus_fpath)
dsm = load_dataset_jax(samples_minus_fpath)

rng_key = random.key(52)
k1, k2 = random.split(rng_key)

n_gals = 100_000

total_n_gals = dsp["samples"]["e1"].shape[0]
if n_gals is None:
    n_gals = total_n_gals
assert n_gals <= total_n_gals
subset = random.choice(k1, jnp.arange(total_n_gals), shape=(n_gals,), replace=False)

e1p = dsp["samples"]["e1"][subset]
e2p = dsp["samples"]["e2"][subset]
e1e2p = jnp.stack([e1p, e2p], axis=-1)

e1m = dsm["samples"]["e1"][subset]
e2m = dsm["samples"]["e2"][subset]
e1e2m = jnp.stack([e1m, e2m], axis=-1)

sigma_e = dsp["hyper"]["shape_noise"]
sigma_e_int = dsp["hyper"]["sigma_e_int"]
assert sigma_e == dsm["hyper"]["shape_noise"]
assert sigma_e_int == dsm["hyper"]["sigma_e_int"]
assert jnp.all(dsp["truth"]["e1"] == dsm["truth"]["e1"])
assert jnp.all(dsp["truth"]["lf"] == dsm["truth"]["lf"])

In [8]:
e1e2p.shape, e1e2m.shape

((100000, 300, 2), (100000, 300, 2))

# Evaluate likelihood timing

## gpu

In [9]:
_logprior = lambda e, g: true_ellip_logprior(e, g, sigma_e=sigma_e)
_interim_logprior = lambda e: jnp.log(ellip_prior_e1e2(e, sigma=sigma_e_int))

_loglikelihood = partial(
    shear_loglikelihood, logprior=_logprior, interim_logprior=_interim_logprior
)
_loglikelihood_jitted = jit(_loglikelihood, device=GPU)

_logtarget = jit(partial(logtarget_shear, loglikelihood=_loglikelihood_jitted), device=GPU)

In [10]:
%%timeit
_logtarget(g=jnp.array([0.02,0.0]), data=e1e2p)

227 μs ± 45.7 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## cpu

In [11]:
_logprior = lambda e, g: true_ellip_logprior(e, g, sigma_e=sigma_e)
_interim_logprior = lambda e: jnp.log(ellip_prior_e1e2(e, sigma=sigma_e_int))

_loglikelihood = partial(
    shear_loglikelihood, logprior=_logprior, interim_logprior=_interim_logprior
)
_loglikelihood_jitted = jit(_loglikelihood, device=CPU)

_logtarget = jit(partial(logtarget_shear, loglikelihood=_loglikelihood_jitted), device=CPU)

In [12]:
%%timeit
_logtarget(g=jnp.array([0.02,0.0]), data=e1e2p)

202 ms ± 70.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
