In [2]:
import datetime
import os

# Expose only GPU 0 to this process. This is set before `import jax` below so
# it is already in effect when JAX initializes its CUDA backend.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


import time
from functools import partial
from pathlib import Path

import blackjax
import galsim
import jax
import jax.numpy as jnp
import jax_galsim as xgalsim
import numpy as np
import optax
from jax import jit as jjit
from jax import random, vmap
from jax.scipy import stats

from bpd.chains import inference_loop
from bpd.draw import add_noise
from bpd.measure import get_snr

print("devices available:", jax.devices())

# NOTE(review): not referenced in the visible cells; presumably meant for
# caching chain outputs. Confirm or remove.
SCRATCH_DIR = Path("/pscratch/sd/i/imendoza/data/cache_chains")


# GPU preamble
GPU = jax.devices("gpu")[0]

# Arrays / computations without an explicit placement default to this GPU.
jax.config.update("jax_default_device", GPU)

# Enable 64-bit dtypes (float64/int64) for arrays created after this point.
jax.config.update("jax_enable_x64", True)

devices available: [CudaDevice(id=0)]


In [3]:
# --- Image rendering --------------------------------------------------------
PIXEL_SCALE = 0.2  # passed to drawImage as `scale`; converts pixel offsets to sky units
BACKGROUND = 1e4  # per-pixel noise variance (the likelihood uses sqrt(BACKGROUND) as sigma)
SLEN = 53  # image side length in pixels
PSF_HLR = 0.7
GSPARAMS = xgalsim.GSParams(minimum_fft_size=256, maximum_fft_size=256)

# --- Ground-truth galaxy (x, y offsets are in pixels) ------------------------
LOG_FLUX, HLR = 4.5, 0.9
G1, G2 = 0.05, 0.0
X, Y = 0.0, 0.0

TRUE_PARAMS = dict(f=LOG_FLUX, hlr=HLR, g1=G1, g2=G2, x=X, y=Y)

# --- Priors ------------------------------------------------------------------
# (low, high) tuples define uniform priors; the scalar entries for the
# centroid are a Gaussian sigma in pixels.
BOUNDS = dict(
    f=(-1.0, 9.0),
    hlr=(0.01, 5.0),
    g1=(-0.7, 0.7),
    g2=(-0.7, 0.7),
    x=1,
    y=1,
)

# make sure relevant things are in GPU
TRUE_PARAMS_GPU, BACKGROUND_GPU, BOUNDS_GPU = [
    jax.device_put(obj, device=GPU) for obj in (TRUE_PARAMS, BACKGROUND, BOUNDS)
]

# --- Sampler run -------------------------------------------------------------
IS_MATRIX_DIAGONAL = True
N_WARMUPS, N_SAMPLES = 500, 1000
SEED = 42

# --- ChEES warmup: Adam learning rate and initial step size ------------------
LR = 1e-3
INIT_STEP_SIZE = 0.1


# sample from ball around some dictionary of true params
def sample_ball(rng_key, center_params: dict):
    """Draw one starting point uniformly from a box around ``center_params``.

    Each parameter gets its own subkey, split from ``rng_key`` in the iteration
    order of ``center_params``. It is drawn from
    ``Uniform(center - w, center + w)`` with the half-width ``w`` from the
    table below. Parameters without a half-width entry are omitted.

    Args:
        rng_key: JAX PRNG key.
        center_params: mapping of parameter name -> center value.

    Returns:
        dict of scalar samples for the recognized parameter names.
    """
    half_widths = {
        "f": 0.25,
        "hlr": 0.2,
        "g1": 0.025,
        "g2": 0.025,
        "x": 0.5,
        "y": 0.5,
    }
    subkeys = random.split(rng_key, len(center_params))

    sampled = {}
    for (name, center), subkey in zip(center_params.items(), subkeys):
        if name not in half_widths:
            continue
        width = half_widths[name]
        sampled[name] = random.uniform(
            subkey, shape=(), minval=center - width, maxval=center + width
        )
    return sampled


def _draw_gal():
    """Render the noiseless "true" galaxy image with GalSim.

    Uses the module-level truth (LOG_FLUX, HLR, G1, G2, X, Y) and the same
    geometry as ``draw_gal``: a Gaussian galaxy, shifted, then sheared, then
    convolved with a Gaussian PSF. It is drawn on an SLEN x SLEN grid at
    PIXEL_SCALE.

    Returns:
        The image pixel array.
    """
    gal = galsim.Gaussian(flux=10**LOG_FLUX, half_light_radius=HLR)
    # X, Y are centroid offsets in pixels (the convention of `draw_gal`,
    # TRUE_PARAMS and the x/y prior sigma). galsim shifts are in the units of
    # `scale`, so convert with PIXEL_SCALE; otherwise the data and the
    # likelihood model disagree whenever X or Y is nonzero.
    gal = gal.shift(dx=X * PIXEL_SCALE, dy=Y * PIXEL_SCALE)
    gal = gal.shear(g1=G1, g2=G2)

    psf = galsim.Gaussian(flux=1.0, half_light_radius=PSF_HLR)
    gal_conv = galsim.Convolve([gal, psf])
    image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE)
    return image.array


def draw_gal(f, hlr, g1, g2, x, y):
    """Render the PSF-convolved Gaussian galaxy model with JAX-GalSim.

    Args:
        f: log10 of the galaxy flux.
        hlr: galaxy half-light radius.
        g1, g2: reduced shear components.
        x, y: centroid offsets in pixels (converted with PIXEL_SCALE).

    Returns:
        The SLEN x SLEN model image array.
    """
    dx_sky = x * PIXEL_SCALE
    dy_sky = y * PIXEL_SCALE
    galaxy = (
        xgalsim.Gaussian(flux=10**f, half_light_radius=hlr)
        .shift(dx=dx_sky, dy=dy_sky)
        .shear(g1=g1, g2=g2)
    )

    psf = xgalsim.Gaussian(flux=1, half_light_radius=PSF_HLR)
    convolved = xgalsim.Convolve([galaxy, psf]).withGSParams(GSPARAMS)
    return convolved.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE).array


def _logprob_fn(params, data):
    """Unnormalized log-posterior of the galaxy parameters given one image.

    Args:
        params: dict with keys "f", "hlr", "g1", "g2", "x", "y" (the keyword
            arguments of ``draw_gal``).
        data: observed image, compared pixel-by-pixel with ``draw_gal(**params)``.

    Returns:
        Scalar log prior + Gaussian log likelihood with per-pixel noise
        sigma = sqrt(BACKGROUND).
    """
    # prior
    prior = jnp.array(0.0, device=GPU)
    for p in ("f", "hlr", "g1", "g2"):  # uniform on [low, high]
        low, high = BOUNDS_GPU[p]
        prior += stats.uniform.logpdf(params[p], low, high - low)

    for p in ("x", "y"):  # zero-mean normal with width sigma (pixels)
        sigma = BOUNDS_GPU[p]
        # `scale` must be passed by keyword: the second positional argument of
        # `norm.logpdf` is `loc`, so `logpdf(x, sigma)` would center the prior
        # at `sigma` with unit width instead of at 0 with width `sigma`.
        prior += stats.norm.logpdf(params[p], loc=0.0, scale=sigma)

    # likelihood: independent Gaussian noise per pixel, variance BACKGROUND
    model = draw_gal(**params)
    likelihood_pp = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(BACKGROUND_GPU))
    likelihood = jnp.sum(likelihood_pp)

    return prior + likelihood

In [5]:
# Number of independent chains. Every per-chain key / initial position below is
# sized from this, so they cannot drift apart.
N_CHAINS = 100

snr = get_snr(_draw_gal(), BACKGROUND)
print("galaxy snr:", snr)

# get data: one noisy realization of the true image (n=1 presumably requests a
# single realization; [0] selects it), then move it to the GPU
_data = add_noise(_draw_gal(), BACKGROUND, rng=np.random.default_rng(SEED), n=1)[0]
data_gpu = jax.device_put(_data, device=GPU)
print("data info:", data_gpu.devices(), type(data_gpu), data_gpu.shape)

# collect random keys we need: one root key split into independent streams
rng_key = jax.device_put(random.key(SEED), device=GPU)
ball_key, warmup_key, sample_key = random.split(rng_key, 3)

ball_keys = random.split(ball_key, N_CHAINS)
sample_keys = random.split(sample_key, N_CHAINS)

# get initial positions for all chains (each leaf has a leading N_CHAINS axis)
all_init_positions = vmap(sample_ball, in_axes=(0, None))(ball_keys, TRUE_PARAMS_GPU)


galaxy snr: 18.25107
data info: {CudaDevice(id=0)} <class 'jaxlib.xla_extension.ArrayImpl'> (53, 53)


In [7]:
# Sanity check: one sample key and one initial position per chain.
(sample_keys.shape, all_init_positions["f"].shape)

((100,), (100,))

In [None]:
# jit and vmap functions to run chains.
# Built lazily: `do_inference` is defined in a later cell, so referencing it
# when this cell executes raises NameError on a fresh kernel (Restart & Run All).
def _run_inference(rng_keys, init_states, data, tuned_params):
    """Run `do_inference` for every chain.

    `rng_keys` and `init_states` are batched along axis 0. `data` and
    `tuned_params` are shared by all chains (same as
    ``in_axes=(0, 0, None, None)``).
    """
    # Close over the shared arguments instead of passing them to the jitted
    # function. NOTE(review): the ChEES tuned parameters may contain callables,
    # which are not valid jit arguments; confirm against blackjax.
    run_one = jjit(partial(do_inference, data=data, tuned_params=tuned_params))
    return vmap(run_one)(rng_keys, init_states)

In [44]:
# Adam optimizer used by ChEES to adapt the trajectory length during warmup.
OPTIM = optax.adam(learning_rate=LR)
def do_warmup(rng_key, positions, data, n_chains: int = None):
    """Run ChEES warmup for a batch of chains on one image.

    Args:
        rng_key: PRNG key for the adaptation.
        positions: PyTree of initial positions; each leaf has shape
            (num_chains, ...).
        data: observed image passed to `_logprob_fn`.
        n_chains: number of chains. If None, it is inferred from the leading
            axis of the leaves of `positions`.

    Returns:
        The result of `warmup.run`. This notebook unpacks it as
        ``(last_states, tuned_params), info``.
    """
    if n_chains is None:
        # Previously None was forwarded to chees_adaptation as the chain count.
        n_chains = jax.tree_util.tree_leaves(positions)[0].shape[0]
    logdensity = partial(_logprob_fn, data=data)
    warmup = blackjax.chees_adaptation(logdensity, n_chains)
    return warmup.run(rng_key, positions, INIT_STEP_SIZE, OPTIM, N_WARMUPS)


def do_inference(rng_key, init_states, data, tuned_params: dict):
    """Draw N_SAMPLES dynamic-HMC samples for a single chain.

    Args:
        rng_key: PRNG key for this chain.
        init_states: starting state for this chain (e.g. the warmup's last state).
        data: observed image passed to `_logprob_fn`.
        tuned_params: keyword arguments for `blackjax.dynamic_hmc`; presumably
            the tuned parameters returned by `do_warmup`.

    Returns:
        Whatever `inference_loop` returns (unpacked as ``states, info`` here).
    """
    _logdensity = partial(_logprob_fn, data=data)
    kernel = blackjax.dynamic_hmc(_logdensity, **tuned_params).step
    # Use the configured sample count instead of a hard-coded duplicate.
    return inference_loop(rng_key, init_states, kernel=kernel, n_samples=N_SAMPLES)

In [40]:
# Run ChEES warmup over all chains at once. The chain count is read from the
# initial positions so it always matches the setup cell.
(last_states, tuned_params), _ = do_warmup(
    warmup_key,
    all_init_positions,
    data_gpu,
    n_chains=all_init_positions["f"].shape[0],
)
last_states.position["f"].shape

(100,)

In [45]:
# Batch over chains: rng keys and initial states are mapped along axis 0,
# while the data and tuned parameters are shared by every chain.
_run_inference = vmap(do_inference, (0, 0, None, None))

In [47]:
%%time
states, info = _run_inference(sample_keys, last_states, data_gpu, tuned_params)

CPU times: user 13.5 s, sys: 147 ms, total: 13.7 s
Wall time: 12.8 s
