*09/01/24*
Authors: Ismael Mendoza

Here I check two things: (1) that the jit-compiled jax-galsim function that draws a galaxy model runs under JAX's transfer guard (i.e. without implicit host-to-device transfers), and (2) that the same holds for a blackjax inference loop using a NUTS kernel.

In [1]:
import jax 
import jax.numpy as jnp
from jax.scipy import stats

from jax import random
from jax import jit as jjit

In [2]:
import galsim 
import jax_galsim as xgalsim

2024-09-01 18:25:11.102540: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.6.20). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:
import blackjax
import numpy as np 

import arviz as az
import chainconsumer as cc
import matplotlib.pyplot as plt

import numpyro
import time
from datetime import date

import matplotlib.pyplot as plt 

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from tqdm import tqdm

In [5]:
from functools import partial

In [6]:
from blackjax.diagnostics import effective_sample_size, potential_scale_reduction

In [7]:
import bpd
from bpd.draw import add_noise
from bpd.measure import get_snr


In [8]:
from bpd.chains import inference_loop

In [9]:
# list the devices JAX can see on the default backend
jax.devices()

[CudaDevice(id=0), CudaDevice(id=1), CudaDevice(id=2), CudaDevice(id=3)]

In [10]:
# first GPU device; later cells place arrays on it explicitly with jax.device_put
GPU = jax.devices('gpu')[0]

In [11]:
# enable 64-bit types in JAX
jax.config.update("jax_enable_x64", True) # fair comparison with GalSim, among other reasons

# Drawing galaxy model 

In [19]:
# fixed image / PSF configuration used by every drawing function below
PIXEL_SCALE = 0.2
BACKGROUND = 1e4
SLEN = 53
PSF_HLR = 0.7

# true values of the galaxy parameters to fit
LOG_FLUX = 4.5  # drawing functions use flux = 10**LOG_FLUX
HLR = 0.9
G1, G2 = 0.05, 0.0
X, Y = 0.0, 0.0

# keyed by the argument names of `draw_gal`, so it can be splatted with **
TRUE_PARAMS = dict(f=LOG_FLUX, hlr=HLR, g1=G1, g2=G2, x=X, y=Y)

In [20]:
from functools import partial
# jax drawing
# FFT size is pinned to exactly 256 (minimum == maximum)
# NOTE(review): presumably so the FFT shape is static under jit -- confirm
GSPARAMS = xgalsim.GSParams(minimum_fft_size=256, maximum_fft_size=256)

def draw_gal(f, hlr, g1, g2, x, y):
    """Draw a PSF-convolved Gaussian galaxy with jax-galsim.

    Args:
        f: log10 of the galaxy flux (the profile is built with ``flux=10**f``).
        hlr: half-light radius passed to ``xgalsim.Gaussian``.
        g1, g2: shear components applied to the galaxy.
        x, y: centroid offsets in pixels; multiplied by ``PIXEL_SCALE``
            before being passed to ``shift``.

    Returns:
        The array of the ``SLEN`` x ``SLEN`` image of the convolved model.
    """
    gal = xgalsim.Gaussian(flux=10**f, half_light_radius=hlr)
    gal = gal.shift(dx=x * PIXEL_SCALE, dy=y * PIXEL_SCALE)
    gal = gal.shear(g1=g1, g2=g2)

    psf = xgalsim.Gaussian(flux=1., half_light_radius=PSF_HLR)
    gal_conv = xgalsim.Convolve([gal, psf]).withGSParams(GSPARAMS)

    # nx/ny must stay the host-side SLEN: switching to a device-resident
    # SLEN_GPU raises an error (per the original author's note).
    # The pixel scale uses the module constant PIXEL_SCALE (as in the shift
    # above); the previously referenced PIXEL_SCALE_GPU is never defined in
    # this notebook and would raise NameError on a fresh run.
    image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE)
    return image.array

In [21]:
# copy the true parameters (a dict of Python floats) onto the GPU
TRUE_PARAMS_GPU = jax.device_put(TRUE_PARAMS, device=GPU)

# jit the drawing function and call it once outside any transfer guard,
# so compilation has already happened before the guarded calls below
draw_gal_jitted = jax.jit(draw_gal, backend='gpu')
_ = draw_gal_jitted(**TRUE_PARAMS_GPU)

In [23]:
# host-side Python floats need an implicit host-to-device transfer,
# which the 'disallow' guard rejects (see the error output below)
with jax.transfer_guard('disallow'):
    _ = draw_gal_jitted(**TRUE_PARAMS) # always gives error

XlaRuntimeError: INVALID_ARGUMENT: Disallowed host-to-device transfer: aval=ShapedArray(float64[]), dst_sharding=SingleDeviceSharding(device=CudaDevice(id=0))

In [24]:
# same call with parameters already placed on the GPU: no transfer is needed
with jax.transfer_guard('disallow'):
    _ = draw_gal_jitted(**TRUE_PARAMS_GPU) # OK

# Simple inference loop

In [28]:
def _draw_gal():
    """Draw the true galaxy image with GalSim (non-JAX reference).

    Uses the module-level true parameters (``LOG_FLUX``, ``HLR``, ``G1``,
    ``G2``, ``X``, ``Y``) and image configuration (``PSF_HLR``, ``SLEN``,
    ``PIXEL_SCALE``).

    Returns:
        The array of the ``SLEN`` x ``SLEN`` image of the PSF-convolved galaxy.
    """
    gal = galsim.Gaussian(flux=10**LOG_FLUX, half_light_radius=HLR)
    # X, Y are in pixels (as in `draw_gal`), so convert with PIXEL_SCALE
    # before shifting; this keeps the simulated data consistent with the
    # fitted model if the true offsets are ever made non-zero.
    gal = gal.shift(dx=X * PIXEL_SCALE, dy=Y * PIXEL_SCALE)
    gal = gal.shear(g1=G1, g2=G2)

    psf = galsim.Gaussian(flux=1., half_light_radius=PSF_HLR)
    gal_conv = galsim.Convolve([gal, psf])
    image = gal_conv.drawImage(nx=SLEN, ny=SLEN, scale=PIXEL_SCALE)
    return image.array

In [29]:
# device-resident copies of everything the log-probability closes over
TRUE_PARAMS_GPU = jax.device_put(TRUE_PARAMS, device=GPU)
BACKGROUND_GPU = jax.device_put(BACKGROUND, device=GPU)
# prior hyper-parameters, read by `_logprob_fn`:
#   'f', 'hlr', 'g1', 'g2' -> (lower, upper) bounds of a uniform prior
#   'x', 'y'               -> sigma of a normal prior
BOUNDS = {'f': (-1., 9.), 'hlr': (0.01, 5.0), 
          'g1': (-0.7, 0.7), 'g2': (-0.7, 0.7), 
          'x': 1,  # sigma (in pixels)
          'y':1 # sigma (in pixels)
}
BOUNDS_GPU = jax.device_put(BOUNDS, device=GPU)

In [30]:
def _logprob_fn(params, data):
    """Unnormalized log-posterior of the galaxy parameters given an image.

    Args:
        params: dict with keys 'f', 'hlr', 'g1', 'g2', 'x', 'y'; it is
            splatted into ``draw_gal``.
        data: observed image, compared pixel-by-pixel against the model.

    Returns:
        Scalar: sum of the log-prior terms plus the summed Gaussian pixel
        log-likelihood.
    """
    # prior accumulator, created directly on the GPU
    prior = jnp.array(0., device=GPU)

    # uniform priors on [lower, upper]; logpdf takes (x, loc, scale=width)
    for p in ('f', 'hlr', 'g1', 'g2'):
        lower, upper = BOUNDS_GPU[p]
        prior += stats.uniform.logpdf(params[p], lower, upper - lower)

    # zero-mean normal priors with standard deviation sigma.
    # `scale=` must be a keyword: the second positional argument of
    # `norm.logpdf` is `loc`, so passing sigma positionally would shift the
    # mean to sigma and leave the standard deviation at its default of 1.
    for p in ('x', 'y'):
        sigma = BOUNDS_GPU[p]
        prior += stats.norm.logpdf(params[p], scale=sigma)

    # likelihood: independent Gaussian pixels with variance BACKGROUND
    model = draw_gal(**params)
    likelihood = stats.norm.logpdf(data, loc=model, scale=jnp.sqrt(BACKGROUND_GPU))

    return jnp.sum(prior) + jnp.sum(likelihood)


In [32]:
# get data
SEED = 42

# draw the true image with GalSim, add one noise realization (n=1, take the
# first), then move the result onto the GPU
# NOTE(review): exact noise model is defined in bpd.draw.add_noise -- confirm there
data = add_noise(_draw_gal(), BACKGROUND, 
                 rng=np.random.default_rng(SEED), 
                 n=1)[0]
data_gpu = jax.device_put(data, device=GPU)
# sanity check: device placement, array type and shape
print(data_gpu.devices(), type(data_gpu), data_gpu.shape)

{CudaDevice(id=0)} <class 'jaxlib.xla_extension.ArrayImpl'> (53, 53)


In [37]:
# base rng key
rng_key = jax.random.key(SEED)
# place the key on the GPU explicitly so guarded calls need no transfer for it
rng_key = jax.device_put(rng_key, device=GPU)
print(rng_key.devices())

{CudaDevice(id=0)}


In [38]:
# start the chain at the true parameters (shallow copy of the GPU-resident dict)
init_positions = {**TRUE_PARAMS_GPU}

In [39]:
# warmup function to jit
def call_warmup(rng_key, init_positions, data, n_warmups, max_num_doublings):
    """Run blackjax window adaptation for a NUTS kernel on ``data``.

    Args:
        rng_key: JAX random key for the adaptation.
        init_positions: initial parameter dict for the chain.
        data: observed image; bound into ``_logprob_fn``.
        n_warmups: number of adaptation steps.
        max_num_doublings: forwarded to the NUTS kernel.

    Returns:
        Whatever ``warmup.run`` returns:
        ``(init_states, tuned_params), adapt_info``.
    """
    adaptation = blackjax.window_adaptation(
        blackjax.nuts,
        partial(_logprob_fn, data=data),
        progress_bar=False,
        is_mass_matrix_diagonal=False,
        max_num_doublings=max_num_doublings,
        initial_step_size=0.1,
        # a high target acceptance rate sometimes prevents divergences by
        # decreasing the final step size, although the priors should now be
        # wide enough
        target_acceptance_rate=0.90,
    )
    return adaptation.run(rng_key, init_positions, n_warmups)

In [40]:
# fix the static arguments with partial, jit, and call once outside the
# transfer guard so compilation happens before the guarded run
run_warmup = jjit(partial(call_warmup, n_warmups=10, max_num_doublings=5))
_ = run_warmup(rng_key, init_positions, data_gpu)

In [46]:
# all inputs are already on the GPU; keep the adapted state and tuned
# parameters for the inference cells below
with jax.transfer_guard('disallow'):
    (state, tuned_params), _ = run_warmup(rng_key, init_positions, data_gpu) #WARMUP OK after compilation

In [51]:
# vmap only rng_key
def do_inference(rng_key, init_state, data, tuned_params: dict, n_samples: int):
    """Sample with a NUTS kernel built from already-tuned parameters.

    Args:
        rng_key: JAX random key for sampling.
        init_state: starting state for the kernel.
        data: observed image; bound into ``_logprob_fn``.
        tuned_params: keyword arguments forwarded to ``blackjax.nuts``.
        n_samples: number of samples requested from ``inference_loop``.

    Returns:
        Whatever ``bpd.chains.inference_loop`` returns.
    """
    nuts = blackjax.nuts(partial(_logprob_fn, data=data), **tuned_params)
    return inference_loop(rng_key, init_state, kernel=nuts.step, n_samples=n_samples)
    


In [55]:
# bind the sample count and tuned parameters, jit, and call once outside
# the transfer guard so compilation happens before the guarded run
_run_inference =jax.jit(partial(do_inference, n_samples=10, tuned_params=tuned_params))
_ = _run_inference(rng_key, state, data_gpu)

In [56]:
# repeat the already-compiled inference call with implicit transfers disallowed
with jax.transfer_guard('disallow'):
    _ = _run_inference(rng_key, state, data_gpu) #INFERENCE OK after compilation