### Tutorial: Running Multiple Chains in hamiltorch

```python
import torch
import hamiltorch
import matplotlib.pyplot as plt
%matplotlib inline
```

```python
print(hamiltorch.__version__)
```

```
0.4.1
```


## Sampling a multivariate Gaussian with multiple chains

As usual, we first define the log probability function for our example, a zero-mean 3D Gaussian with a diagonal covariance:

```python
def log_prob(omega):
    # Zero-mean 3D Gaussian with independent components (diagonal covariance)
    mean = torch.tensor([0., 0., 0.])
    stddev = torch.tensor([.5, 1., 2.])
    return torch.distributions.MultivariateNormal(mean, torch.diag(stddev**2)).log_prob(omega).sum()
```
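Because the covariance is diagonal, this log density is simply the sum of three independent 1D normal log densities. The dependency-free sketch below writes that out with the `math` module, which can serve as a cross-check for `log_prob`; the helper name `diag_gaussian_log_prob` is hypothetical and not part of hamiltorch.

```python
import math


def diag_gaussian_log_prob(x, mean, stddev):
    """Log density of N(mean, diag(stddev**2)) at x, summed over independent dimensions."""
    total = 0.0
    for xi, mi, si in zip(x, mean, stddev):
        z = (xi - mi) / si
        # log N(xi | mi, si^2) = -z^2 / 2 - log(si) - log(2*pi) / 2
        total += -0.5 * z * z - math.log(si) - 0.5 * math.log(2 * math.pi)
    return total


# Same target as log_prob above: mean 0, stddev (0.5, 1, 2)
print(diag_gaussian_log_prob([0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.5, 1.0, 2.0]))  # ≈ -2.7568
```

At the origin the log-stddev terms cancel (log 0.5 + log 1 + log 2 = 0), leaving -1.5 log(2π).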

```python
# Set up the HMC parameters
N = 400          # number of samples per chain
step_size = .3   # leapfrog step size
L = 5            # leapfrog steps per sample
```
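`step_size` and `L` control the leapfrog integrator at the core of HMC: each proposal takes `L` leapfrog steps of size `step_size`, followed by a Metropolis accept/reject test. The stdlib-only sketch below illustrates that textbook scheme for this Gaussian target using its analytic gradient. It is a generic HMC step, not hamiltorch's implementation, and the names `leapfrog` and `hmc_step` are hypothetical.

```python
import math
import random

STDDEV = [0.5, 1.0, 2.0]  # same target as log_prob: zero mean, diagonal covariance


def log_prob(x):
    # Unnormalised log density of the zero-mean diagonal Gaussian
    return -0.5 * sum((xi / si) ** 2 for xi, si in zip(x, STDDEV))


def grad_log_prob(x):
    # d/dx_i log p(x) = -x_i / s_i^2
    return [-xi / si ** 2 for xi, si in zip(x, STDDEV)]


def leapfrog(x, p, step_size, num_steps):
    """Run num_steps leapfrog steps from position x with momentum p (unit mass)."""
    x, p = list(x), list(p)
    g = grad_log_prob(x)
    for _ in range(num_steps):
        p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, g)]  # half step in momentum
        x = [xi + step_size * pi for xi, pi in zip(x, p)]         # full step in position
        g = grad_log_prob(x)
        p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, g)]  # half step in momentum
    return x, p


def hmc_step(x, step_size, num_steps, rng):
    """One HMC transition: draw momentum, integrate, then Metropolis accept/reject."""
    p0 = [rng.gauss(0.0, 1.0) for _ in x]
    x1, p1 = leapfrog(x, p0, step_size, num_steps)
    # Hamiltonian H = -log p(x) + |p|^2 / 2
    h0 = -log_prob(x) + 0.5 * sum(pi * pi for pi in p0)
    h1 = -log_prob(x1) + 0.5 * sum(pi * pi for pi in p1)
    if rng.random() < math.exp(min(0.0, h0 - h1)):
        return x1       # accept the proposal
    return list(x)      # reject: stay at the current position


rng = random.Random(0)
x = [0.0, 0.0, 0.0]
samples = []
for _ in range(400):              # N
    x = hmc_step(x, 0.3, 5, rng)  # step_size, L
    samples.append(x)
```

Larger `step_size` values cover more ground per step but increase the integration error and therefore the rejection rate; `L` trades longer trajectories against more gradient evaluations per sample.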

### Multiple Chains

To sample with multiple chains, we first define a `chain` with the `hamiltorch.util.setup_chain` function, which takes the following arguments:
* `sampler`: The sampler function to call. Here I use hamiltorch's standard sampler, `hamiltorch.sample`.
* `prior`: A function with no arguments that returns initial parameters. It is used to sample the starting point of each chain, so that the chains start from different locations.
* `kwargs`: A dictionary of keyword arguments for the `sampler`, in this example the arguments of `hamiltorch.sample`. Because `params_init` is supplied by the prior, leave it out of `kwargs`.
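Conceptually, a chain built this way is a function of a single random seed: it seeds the random number generator, draws `params_init` from the prior, and calls the sampler with the remaining `kwargs`. The plain-Python sketch below shows that pattern under this assumption; it is not hamiltorch's source, and `make_chain`, `toy_prior` and `toy_sampler` are hypothetical stand-ins. Using `functools.partial` over a top-level function keeps the chain picklable, which matters if the chains are later sent to worker processes.

```python
import functools
import random


def _run_chain(sampler, prior, kwargs, seed):
    random.seed(int(seed))     # seed the global RNG for this chain (cf. torch.manual_seed)
    params_init = prior()      # draw this chain's starting point from the prior
    return sampler(params_init=params_init, **kwargs)


def make_chain(sampler, prior, kwargs):
    """Return a one-argument callable chain(seed); params_init must not be in kwargs."""
    if "params_init" in kwargs:
        raise ValueError("params_init is drawn from the prior; remove it from kwargs")
    return functools.partial(_run_chain, sampler, prior, kwargs)


# Hypothetical stand-ins for the Gaussian prior and hamiltorch.sample
def toy_prior():
    return [0.1 * random.gauss(0.0, 1.0) for _ in range(3)]


def toy_sampler(params_init, num_samples):
    # Random-walk placeholder: just enough to show the calling convention
    x, out = list(params_init), []
    for _ in range(num_samples):
        x = [xi + 0.1 * random.gauss(0.0, 1.0) for xi in x]
        out.append(x)
    return out


chain = make_chain(toy_sampler, toy_prior, {"num_samples": 5})
first = chain(0)  # same seed -> same starting point and same samples
```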

```python
kwargs = {'log_prob_func': log_prob, 'num_samples': N, 'step_size': step_size,
          'num_steps_per_sample': L, 'verbose': True}
num_workers = 4
seeds = torch.arange(8)  # random seeds (0-7) passed to multi_chain together with num_workers
prior = lambda: 0.1 * torch.randn(3)  # Gaussian prior used to draw each chain's initial parameters
chain = hamiltorch.util.setup_chain(hamiltorch.sample, prior, kwargs)
```

### Running multiple chains

The function `hamiltorch.util.multi_chain` takes the `chain`, the number of workers and the list of seeds as arguments, and can run the chains either in parallel or serially.
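A serial/parallel runner of this kind can be sketched as mapping the chain over the seeds, either with a plain loop or with a pool of worker processes. The sketch below is a hypothetical `run_chains` helper built on `concurrent.futures.ProcessPoolExecutor`, not hamiltorch's implementation; `toy_chain` stands in for a real chain. In the parallel case the chain must be picklable, for example a top-level function or a `functools.partial` of one.

```python
import random
from concurrent.futures import ProcessPoolExecutor


def run_chains(chain, num_workers, seeds, parallel=False):
    """Call chain(seed) for every seed and return the results in seed order."""
    seeds = [int(s) for s in seeds]  # convert seeds (e.g. elements of a tensor) to plain ints
    if parallel:
        # Each call runs in a worker process; pool.map preserves the input order
        with ProcessPoolExecutor(max_workers=num_workers) as pool:
            return list(pool.map(chain, seeds))
    return [chain(s) for s in seeds]


def toy_chain(seed):
    # Hypothetical stand-in for a chain: 3 "samples" of a 1D random walk
    rng = random.Random(seed)
    x, out = 0.0, []
    for _ in range(3):
        x += rng.gauss(0.0, 1.0)
        out.append(x)
    return out


results = run_chains(toy_chain, num_workers=4, seeds=range(8), parallel=False)
```

When calling it with `parallel=True` from a script, put the call under an `if __name__ == "__main__":` guard, since platforms that spawn worker processes re-import the main module. For short chains like the ones here, process start-up and data transfer can outweigh the benefit of running in parallel.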

##### Parallel
Depending on your hardware, running the chains in parallel may be faster. To do so, set `parallel=True`:

```python
%%time
parallel = True
params_hmc_par = hamiltorch.util.multi_chain(chain, num_workers, seeds, parallel=parallel)
```

```
CPU times: user 50 s, sys: 4min 3s, total: 4min 53s
Wall time: 29.9 s
```


##### Serial
To run the chains serially, set `parallel=False`:

```python
%%time
parallel = False
params_hmc_ser = hamiltorch.util.multi_chain(chain, num_workers, seeds, parallel=parallel)
```

```
CPU times: user 1min 12s, sys: 1.94 s, total: 1min 14s
Wall time: 12.9 s
```
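Assuming each entry of the returned list holds one chain's samples (one 3-vector per sample), a quick sanity check is to pool the chains, optionally dropping some warm-up samples from each, and compare the per-dimension standard deviation with the target's `stddev` of (0.5, 1, 2). The stdlib-only sketch below works on plain Python lists; `pool_chains` and `per_dim_std` are hypothetical helpers, and hamiltorch tensors would first need converting, e.g. `[[s.tolist() for s in c] for c in params_hmc_ser]`.

```python
import statistics


def pool_chains(chains, burn_in=0):
    """Concatenate the samples of several chains, dropping the first burn_in of each."""
    return [sample for chain in chains for sample in chain[burn_in:]]


def per_dim_std(samples):
    """Sample standard deviation of each coordinate across all pooled samples."""
    return [statistics.stdev(coord) for coord in zip(*samples)]


# Toy input: 2 chains x 3 samples of a 3D parameter
chains = [
    [[0.0, 1.0, 2.0], [1.0, 2.0, 4.0], [2.0, 3.0, 6.0]],
    [[0.0, 1.0, 2.0], [1.0, 2.0, 4.0], [2.0, 3.0, 6.0]],
]
pooled = pool_chains(chains, burn_in=1)  # 2 chains x 2 remaining samples
print(len(pooled), per_dim_std(pooled))
```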


##### Note:

There are likely more efficient ways of running multiple chains, but for now it seemed a useful feature to include in this version. The timings above were recorded on a Mac, where the parallel run actually took about 2.3 times as long as the serial one (29.9 s vs 12.9 s wall time). On a Linux machine I was able to get a 50% speed-up with `parallel=True`, although this was not based on extensive experimentation.