In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pdb
from discs.common import math_util as math
from discs.samplers import locallybalanced
import jax
import numpy as np
from copy import deepcopy
import jax.numpy as jnp
import sys
import ml_collections
from tqdm import tqdm

In [3]:
from discs.samplers.mixed import BlockGibbsSampler, HMCWrapper
from discs.models import abstractmodel
from discs.common import configs as common_configs

In [4]:
# Make the local `network_control` checkout (which provides `ergm_jax`, imported
# in the next cell) importable. The location can be overridden with the
# NETWORK_CONTROL_PATH environment variable instead of editing this
# machine-specific default. The membership check keeps re-runs of this cell
# from appending duplicate entries to sys.path.
import os

NETWORK_CONTROL_PATH = os.environ.get('NETWORK_CONTROL_PATH', '/Users/ankitkumar/nse/network_control')
if NETWORK_CONTROL_PATH not in sys.path:
    sys.path.append(NETWORK_CONTROL_PATH)

In [5]:
import ergm_jax
from discs.samplers.dlmc import BinaryDLMC
from discs.samplers.configs import dlmc_config

### Testing HMC Wrapper

In [6]:
# Sampler settings for the HMC/NUTS wrapper: start from the shared default
# config and override only the HMC step size.
config = common_configs.get_config()
sampler_config = ml_collections.ConfigDict({'step_size': 1e-1})
config.sampler.update(sampler_config)

In [7]:
# Run settings for the HMC wrapper test.
n_chains = 3          # number of chains sampled in parallel
burnin = 100          # initial steps that are not recorded
chain_length = 500    # steps run after burn-in
dim = 10              # length of each chain's state vector

In [8]:
# Set up the HMC test: build the Gaussian test model, draw one initial state
# per chain, and initialise the HMC wrapper's state. The root key is split up
# front so model parameters, initial samples, the sampler and the step loop
# each get their own key.
rnd = jax.random.PRNGKey(4)
init_rng, step_rng, model_rng = jax.random.split(rnd, 3)
model = ergm_jax.Gaussian(dim)  # use the configured dimension, not a literal
model_params = model.make_init_params(model_rng)
sample_init_rng, sampler_init_rng = jax.random.split(init_rng, 2)
x = model.get_init_samples(sample_init_rng, n_chains)
sampler = HMCWrapper(config, model, model_params)
sampler_state = sampler.make_init_state(sampler_init_rng, x)
samples = []

In [9]:
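# Check: does the mask zero out the log-prob gradient wherever mask == 0?
# (In the saved output below, the gradient is 0 at every odd index, i.e. exactly where the mask is 0.)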

In [10]:
# Alternating mask over the state dimensions: 1 at even indices, 0 at odd
# indices, repeated once per chain. It is derived from `dim`/`n_chains` rather
# than a hand-written literal, so it stays in sync with `x`. The integer dtype
# matches the original `jnp.array([1, 0, ...])`.
xmask = jnp.tile(1 - jnp.arange(dim) % 2, (n_chains, 1))
assert xmask.shape == jnp.shape(x), (xmask.shape, jnp.shape(x))
model.get_value_and_grad(model_params, x, xmask)

(Array([-9.704621 , -5.9627075, -8.972922 ], dtype=float32),
 Array([[-0.82199687,  0.        , -0.34370762,  0.        ,  2.6400461 ,
          0.        ,  1.9081315 ,  0.        , -1.9062486 , -0.        ],
        [-1.8160597 ,  0.        , -0.81580585,  0.        ,  1.7433207 ,
          0.        ,  0.7380365 ,  0.        ,  0.08427978, -0.        ],
        [-1.7859294 ,  0.        , -0.83835304,  0.        ,  2.0086946 ,
          0.        ,  2.369164  ,  0.        ,  0.18665051, -0.        ]],      dtype=float32))

In [25]:
# Run the masked HMC chain: discard the first `burnin` steps, then record
# exactly `chain_length` states. `samples` is reset here so re-running this
# cell does not append to a previous run's samples.
samples = []
for idx in tqdm(range(burnin + chain_length)):
    step_rng_, step_rng = jax.random.split(step_rng, 2)
    x, sampler_state, _ = sampler.step(model, step_rng_, x, model_params, sampler_state, xmask)
    # `>=`, not `>`: step `burnin` is the first post-burn-in step.
    if idx >= burnin:
        samples.append(x)
assert len(samples) == chain_length

100%|██████████| 600/600 [00:30<00:00, 19.90it/s]


In [37]:
# List the fields stored in the HMC wrapper's state dict.
sampler_state.keys()

dict_keys(['num_ll_calls', 'steps', 'index', 'accepted_results', 'is_accepted', 'log_accept_ratio', 'proposed_state', 'proposed_results', 'extra', 'seed'])

In [36]:
# Inspect the full HMC wrapper state after the masked run.
# NOTE(review): the saved output shows 'steps' and 'num_ll_calls' still at 0, and
# grads_target_log_prob is non-zero at the masked (odd) indices. The direct
# get_value_and_grad(..., xmask) call earlier returned zeros there. Verify that
# sampler_state and xmask are actually threaded through HMCWrapper.step.
# TODO: display only the fields of interest rather than the whole state.
sampler_state

{'num_ll_calls': Array(0, dtype=int32),
 'steps': Array(0, dtype=int32),
 'index': Array(0, dtype=int32),
 'accepted_results': UncalibratedHamiltonianMonteCarloKernelResults(
   log_acceptance_correction=Array(0., dtype=float32),
   target_log_prob=Array(-40.102097, dtype=float32),
   grads_target_log_prob=[Array([[-0.82199687,  1.474191  , -0.34370762,  3.8726501 ,  2.6400461 ,
             -0.5938172 ,  1.9081315 ,  0.33985704, -1.9062486 , -3.3789067 ],
            [-1.8160597 ,  0.04224145, -0.81580585,  2.4222043 ,  1.7433207 ,
              0.40339226,  0.7380365 ,  0.07806851,  0.08427978, -1.7560818 ],
            [-1.7859294 ,  0.33613083, -0.83835304,  2.3277597 ,  2.0086946 ,
             -0.48197478,  2.369164  , -0.06954838,  0.18665051,  0.30052185]],      dtype=float32)],
   initial_momentum=[Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)],
   fin

In [10]:
# How much overhead does the NUTS sampler add compared to HMC? For the Gaussian test case, roughly
# a factor of 10. What ultimately matters is effective sample size (ESS) per unit of CPU time. We could
# also try step-size adaptation: https://www.tensorflow.org/probability/api_docs/python/tfp/mcmc/SimpleStepSizeAdaptation

### Testing Mixed Sampler

In [7]:
# Mixed discrete/continuous test on a weighted SBM. Key sampling points:
# - the adjacency matrix (discrete) and the weight matrix (continuous) must be
#   updated in separate blocks;
# - conditioned on the discrete support, the weight sampler should only update
#   non-zero weights, or should update zero weights only with small probability.
burnin = 1000
chain_length = 5000
n_chains = 3

# Initialize model up front: N nodes split into two equal-size groups.
N = 10
g = np.zeros(N)   # group (community) assignment per node
g[N // 2:] = 1    # second half of the nodes -> group 1 (2 groups for now)

config = common_configs.get_config()
# Only the sample shape and num_categories (2, i.e. binary) need to be set.
model_config = ml_collections.ConfigDict({'shape': (N, N), 'num_categories': 2})
config.model.update(model_config)

#### Product model initialization
rng = jax.random.PRNGKey(420)
# The first key only fed a throwaway initial sample that the sampling cell
# redraws with its own key, so it is discarded here.
_, param_init_rng, rng = jax.random.split(rng, 3)
model0 = ergm_jax.WeightedSBM(comm_assignment=g, N=N)
model1 = ergm_jax.deg_corrected_SBM(N=N, comm_assignment=g, sbm_only=False, degree_only=False)
model = ergm_jax.ERGMProduct(N=N, ergm_children=(model0, model1), domain_categories=(1, 0))
model_params = model.make_init_params(param_init_rng)

# Discrete block: DLMC. It gets a deep copy of the config, so the HMC step size
# set below does not leak into it.
dsampler_config = dlmc_config.get_config()
dconfig = deepcopy(config)
dconfig.sampler.update(dsampler_config)
dsampler = BinaryDLMC(dconfig)

# Continuous block: HMC on the WeightedSBM (model0) parameters only. Assumes
# they occupy the first model0.N_pars entries of model_params -- TODO confirm.
sampler_config = ml_collections.ConfigDict({'step_size': 5})
config.sampler.update(sampler_config)
csampler = HMCWrapper(config, model0, model_params[0:model0.N_pars])
sampler = BlockGibbsSampler(dsampler, csampler)

rnd = jax.random.PRNGKey(165)
init_rng, step_rng, model_rng = jax.random.split(rnd, 3)
sample_init_rng, sampler_init_rng = jax.random.split(init_rng, 2)

In [9]:
# Draw a fresh initial state, split it into its discrete (xd) and continuous
# (xc) parts, and run the block-Gibbs sampler. The first `burnin` sweeps are
# discarded, then exactly `chain_length` (xd, xc) pairs are recorded.
x = model.get_init_samples(sample_init_rng, n_chains)
xd, xc = model.separate_sample(x)
sampler_state = sampler.make_init_state(sampler_init_rng, xd, xc)
samples = []
for idx in tqdm(range(burnin + chain_length)):
    step_rng_, step_rng = jax.random.split(step_rng, 2)
    xd, xc, sampler_state = sampler.step(model, step_rng_, xd, xc, model_params, sampler_state)
    # `>=`, not `>`: sweep `burnin` is the first post-burn-in sweep.
    if idx >= burnin:
        samples.append((xd, xc))

  1%|          | 64/6000 [00:03<05:35, 17.72it/s] 


KeyboardInterrupt: 

### Simpler test first: Binomial × Gaussian product model

In [29]:
# Simpler mixed test: a Gaussian (continuous, sampled by HMC) x Binomial
# (discrete, sampled by DLMC) product model. As in the SBM case, the two parts
# are updated in separate blocks, and, conditioned on the discrete support, the
# continuous sampler should only update non-zero entries (or should update zero
# entries only with small probability).
burnin = 1000
chain_length = 5000
n_chains = 3

# Initialize model up front
N = 10
config = common_configs.get_config()
# Only the sample shape and num_categories (2, i.e. binary) need to be set.
model_config = ml_collections.ConfigDict({'shape': (N,), 'num_categories': 2})
config.model.update(model_config)

#### Product model initialization
rng = jax.random.PRNGKey(420)
# The first key only fed a throwaway initial sample that the sampling cell
# redraws with its own key, so it is discarded here.
_, param_init_rng, rng = jax.random.split(rng, 3)
model0 = ergm_jax.Gaussian(N)   # both factors follow N instead of a literal 10
model1 = ergm_jax.Binomial(N)
model = ergm_jax.ERGMProduct(N=N, ergm_children=(model0, model1), domain_categories=(1, 0))
model_params = model.make_init_params(param_init_rng)

# Discrete block: DLMC. It gets a deep copy of the config, so the HMC step size
# set below does not leak into it.
dsampler_config = dlmc_config.get_config()
dconfig = deepcopy(config)
dconfig.sampler.update(dsampler_config)
dsampler = BinaryDLMC(dconfig)

# Continuous block: HMC on the Gaussian (model0) parameters only. Assumes they
# occupy the first model0.N_pars entries of model_params -- TODO confirm.
sampler_config = ml_collections.ConfigDict({'step_size': 5})
config.sampler.update(sampler_config)
csampler = HMCWrapper(config, model0, model_params[0:model0.N_pars])
sampler = BlockGibbsSampler(dsampler, csampler)

rnd = jax.random.PRNGKey(165)
init_rng, step_rng, model_rng = jax.random.split(rnd, 3)
sample_init_rng, sampler_init_rng = jax.random.split(init_rng, 2)

In [None]:
# Implementation question: ERGMProduct currently assumes a separable
# energy function.

# We need to take gradients with respect to the samples...
# TODO: finish this note (it was left mid-sentence: "This means that we take the ...").

In [14]:
# Draw a fresh initial state, split it into its discrete (xd) and continuous
# (xc) parts, and run the block-Gibbs sampler. The first `burnin` sweeps are
# discarded, then exactly `chain_length` (xd, xc) pairs are recorded.
x = model.get_init_samples(sample_init_rng, n_chains)
xd, xc = model.separate_sample(x)
sampler_state = sampler.make_init_state(sampler_init_rng, xd, xc)
samples = []

for idx in tqdm(range(burnin + chain_length)):
    step_rng_, step_rng = jax.random.split(step_rng, 2)
    xd, xc, sampler_state = sampler.step(model, step_rng_, xd, xc, model_params, sampler_state)
    # `>=`, not `>`: sweep `burnin` is the first post-burn-in sweep.
    if idx >= burnin:
        samples.append((xd, xc))

  1%|          | 32/6000 [00:03<11:00,  9.04it/s] 


KeyboardInterrupt: 

In [None]:
# Simplest test one can engineer - do we recover the right
# expected values from the samples?

In [None]:
# How do we test convergence to the right distribution? Try a histogram
# test (e.g. chi-squared), keeping the dimensionality of the problem
# low...

# Before getting quantitative: if we can manage to sample from the
# ground-truth distribution directly, we can always just compare
# the marginal histograms...

In [None]:
# Open question: how does one calculate the effective sample size (ESS)
# of a distribution with mixed support? We could just go ahead
# and calculate it anyway.