# Benchmark 2

Each chain starts from a different initial position drawn from the prior. We run multiple chains, each with its own warmup, vectorized over chains with `jax.vmap`, and all chains target the same noisy image.

In [28]:
SEED = 42  # seeds both the NumPy noise generator and the base JAX PRNG key used below

In [29]:
# Simulate the target image once and add noise to it (n=1 -- presumably a single noise
# realization; the result is a single (53, 53) image). Every chain in this benchmark
# targets this same image, which is then placed on the GPU device.
data = add_noise(
    _draw_gal(),
    BACKGROUND,
    rng=np.random.default_rng(SEED),
    n=1,
)
data_gpu = jax.device_put(data, device=GPU)
print(data_gpu.devices(), type(data_gpu), data_gpu.shape)

{CudaDevice(id=0)} <class 'jaxlib.xla_extension.ArrayImpl'> (53, 53)


In [31]:
# Base PRNG key, placed on the GPU, then split into three independent streams:
# prior draws (initial positions), post-warmup sampling, and warmup/adaptation.
rng_key = jax.device_put(jax.random.key(SEED), device=GPU)
prior_key, sample_key, warmup_key = random.split(rng_key, 3)
print(rng_key.devices())

{CudaDevice(id=0)}


In [32]:
# Benchmark grid: the loop below runs every combination of these values.
MAX_DOUBLINGS = (5, 10)                     # NUTS max_num_doublings (tree-depth cap)
N_WARMUPS = (10, 50, 100, 200, 500, 1000)   # number of window-adaptation steps
N_CHAINS = (1, 5, 10, 25, 50, 100)          # chains batched with vmap; author's note: probably explodes beyond 1000

In [33]:
# Empty result slots for every grid point, indexed as
# results_dict2[max_num_doublings][n_warmups][n_chains]. Each leaf is a fresh,
# independent dict that the benchmark loop fills with timings, states and sampler info.
results_dict2 = {
    doublings: {
        warmups: {chains: {} for chains in N_CHAINS}
        for warmups in N_WARMUPS
    }
    for doublings in MAX_DOUBLINGS
}

In [34]:
# Prior draws for the largest chain count (presumably max(N_CHAINS) samples per parameter);
# each configuration in the loop uses only the first n_chains entries of every parameter.
all_init_positions = prior_sample(prior_key, max(N_CHAINS)) # subsampled below

In [45]:
# Single-chain warmup; the benchmark loop batches it over chains with jax.vmap,
# mapping both rng_key and init_position while sharing data.
def do_warmup(rng_key, init_position: dict, data, n_warmups: int, max_num_doublings: int):
    """Tune NUTS for one chain using blackjax window adaptation.

    Args:
        rng_key: PRNG key for this chain's adaptation.
        init_position: starting position for the chain (dict of parameters;
            presumably one prior draw).
        data: observed image, bound into ``_logprob_fn`` as ``data=``.
        n_warmups: number of adaptation steps. The loop binds it with
            ``functools.partial`` so it stays a Python int under tracing.
        max_num_doublings: NUTS tree-depth cap forwarded to the kernel.

    Returns:
        The result of ``warmup.run``: ``((state, tuned_params), adapt_info)``, where
        ``tuned_params`` contains ``step_size`` and ``inverse_mass_matrix``.
    """
    adaptation = blackjax.window_adaptation(
        blackjax.nuts,
        partial(_logprob_fn, data=data),
        progress_bar=False,
        is_mass_matrix_diagonal=False,  # adapt a dense (non-diagonal) mass matrix
        max_num_doublings=max_num_doublings,
        initial_step_size=0.1,
        target_acceptance_rate=0.90,
    )
    return adaptation.run(rng_key, init_position, n_warmups)

    
def do_inference(rng_key, init_state, data, step_size: float, inverse_mass_matrix, max_num_doublings: int, n_samples: int):
    """Sample one chain with NUTS using already-tuned parameters.

    Args:
        rng_key: PRNG key for this chain's sampling.
        init_state: state returned by the warmup for this chain.
        data: observed image, bound into ``_logprob_fn`` as ``data=``.
        step_size: tuned NUTS step size.
        inverse_mass_matrix: tuned inverse mass matrix.
        max_num_doublings: NUTS tree-depth cap.
        n_samples: number of post-warmup draws to run through ``inference_loop``.

    Returns:
        Whatever ``inference_loop`` returns; the benchmark unpacks it as ``(states, infos)``.
    """
    nuts = blackjax.nuts(
        partial(_logprob_fn, data=data),
        step_size=step_size,
        inverse_mass_matrix=inverse_mass_matrix,
        max_num_doublings=max_num_doublings,
    )
    return inference_loop(rng_key, init_state, kernel=nuts.step, n_samples=n_samples)
    


In [49]:
# Here the warmup is run independently for every chain, even though all chains target the same data.
# Whether a single warmup per image is enough is an open question; these timings may help clarify it.
#
# Timing protocol per configuration: the first call of each jitted function includes tracing,
# XLA compilation and execution ("*_comp_time"); the second, identical call reuses the cached
# executable ("*_run_time").
TOTAL_SAMPLES = 1000  # post-warmup draws per configuration, split evenly (integer division) across chains


def _timed(fn, *args):
    """Call ``fn(*args)``, block until every output is ready, and return ``(outputs, seconds)``."""
    start = time.perf_counter()  # monotonic, high-resolution clock meant for intervals
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start


print('md, n_warmup, n_chains')
for md in MAX_DOUBLINGS:
    for n_warmups in N_WARMUPS:
        for n_chains in N_CHAINS:
            print(md, n_warmups, n_chains)
            res = results_dict2[md][n_warmups][n_chains]
            n_samples = TOTAL_SAMPLES // n_chains

            # jit the *vmapped* function so the whole batch of chains is compiled as one program.
            # With jax.vmap(jax.jit(f)), the vmap transformation runs outside the compiled code on
            # every call, so its overhead would leak into the "run time" measurement.
            _run_warmup = jax.jit(
                jax.vmap(
                    partial(do_warmup, n_warmups=n_warmups, max_num_doublings=md),
                    in_axes=(0, 0, None),  # per-chain keys and initial positions; shared data
                )
            )
            _run_inference = jax.jit(
                jax.vmap(
                    partial(do_inference, max_num_doublings=md, n_samples=n_samples),
                    in_axes=(0, 0, None, 0, 0),  # per-chain key, state, step size, mass matrix; shared data
                )
            )

            # Per-chain PRNG keys, and the first n_chains prior draws as initial positions.
            warmup_keys = random.split(warmup_key, n_chains)
            sample_keys = random.split(sample_key, n_chains)
            _init_positions = {p: q[:n_chains] for p, q in all_init_positions.items()}

            # First calls: compile + run.
            ((_states, _tuned_params), _), res['warmup_comp_time'] = _timed(
                _run_warmup, warmup_keys, _init_positions, data_gpu)
            _, res['inference_comp_time'] = _timed(
                _run_inference, sample_keys, _states, data_gpu,
                _tuned_params['step_size'], _tuned_params['inverse_mass_matrix'])

            # Second calls: cached executables, so this is run time only.
            ((init_states, tuned_params), _), res['warmup_run_time'] = _timed(
                _run_warmup, warmup_keys, _init_positions, data_gpu)
            (states, infos), res['inference_run_time'] = _timed(
                _run_inference, sample_keys, init_states, data_gpu,
                tuned_params['step_size'], tuned_params['inverse_mass_matrix'])

            # Keep the samples and sampler diagnostics for later analysis.
            res['states'] = states
            res['info'] = infos
print('DONE!')

md, n_warmup, n_chains
5 10 1
5 10 5
5 10 10
5 10 25
5 10 50
5 10 100
5 50 1
5 50 5
5 50 10
5 50 25
5 50 50
5 50 100
5 100 1
5 100 5
5 100 10
5 100 25
5 100 50
5 100 100
5 200 1
5 200 5
5 200 10
5 200 25
5 200 50
5 200 100
5 500 1
5 500 5
5 500 10
5 500 25
5 500 50
5 500 100
5 1000 1
5 1000 5
5 1000 10
5 1000 25
5 1000 50
5 1000 100
10 10 1
10 10 5
10 10 10
10 10 25
10 10 50
10 10 100
10 50 1
10 50 5
10 50 10
10 50 25
10 50 50
10 50 100
10 100 1
10 100 5
10 100 10
10 100 50
10 100 100
10 200 1
10 200 5
10 200 10
10 200 25


KeyboardInterrupt: 

In [52]:
# Optionally checkpoint the (possibly partial) benchmark results to disk.
# This is off by default so that Restart & Run All does not overwrite an existing checkpoint.
# Set SAVE_PARTIAL_RESULTS = True (e.g. after an interrupted run) to save what was collected.
import pickle

SAVE_PARTIAL_RESULTS = False
if SAVE_PARTIAL_RESULTS:
    with open('partial_results_benchmark2.pickle', 'wb') as f:
        pickle.dump(results_dict2, f, protocol=pickle.HIGHEST_PROTOCOL)

In [55]:
# Reload the partial-results checkpoint (presumably written by the dump cell above after the
# interrupted benchmark run). Only unpickle files you produced yourself: pickle can execute code.
import pickle

with open('partial_results_benchmark2.pickle', mode='rb') as f:
    ex = pickle.load(f)

# Fields stored for one grid point: max_num_doublings=5, n_warmups=10, n_chains=10.
ex[5][10][10].keys()

dict_keys(['warmup_comp_time', 'inference_comp_time', 'warmup_run_time', 'inference_run_time', 'states', 'info'])

In [None]:
md = 5


def _timings_for(results_by_warmup, key):
    """Return ``{n_warmups: {n_chains: seconds}}`` for ``key`` in one max_num_doublings slice.

    Grid points that never ran (e.g. after the KeyboardInterrupt above) have no timing
    keys and are reported as NaN instead of raising KeyError.
    """
    return {
        n_warmups: {
            n_chains: res.get(key, float('nan'))
            for n_chains, res in by_chains.items()
        }
        for n_warmups, by_chains in results_by_warmup.items()
    }


# NOTE(review): the intended contents of this cell were not written out; this assumes
# "warmup_times" means the cached warmup run time and "run_times" means the inference run
# time -- TODO confirm. Replace results_dict2 with `ex` to work from the reloaded pickle instead.
warmup_times = _timings_for(results_dict2[md], 'warmup_run_time')
run_times = _timings_for(results_dict2[md], 'inference_run_time')