Skip to content

Commit

Permalink
Merge pull request #36 from arnauqb/mcmc_kernel
Browse files Browse the repository at this point in the history
Defined MCMCKernel class to show required form
  • Loading branch information
arnauqb committed Jul 13, 2023
2 parents e708cf1 + 6c48d99 commit 3e53321
Showing 1 changed file with 72 additions and 9 deletions.
81 changes: 72 additions & 9 deletions blackbirds/infer/mcmc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from abc import ABC, abstractmethod
import logging
import numpy as np
import torch
Expand All @@ -6,10 +7,11 @@

logger = logging.getLogger("mcmc")

class MCMCKernel(ABC):

class MALA:
"""
Class that generates a step in the chain of a Metropolis-Adjusted Langevin Algorithm run.
Abstract base class for MCMC kernels. These kernels specify how to sample the next step in
a generic MCMC chain.
**Arguments**
Expand All @@ -21,7 +23,6 @@ class MALA:
- `jacobian_chunk_size`: The number of rows computed at a time for the model Jacobian. Set to None to compute the full Jacobian at once.
- `gradient_horizon`: The number of timesteps to use for the gradient horizon. Set 0 to use the full trajectory.
- `device`: The device to use for training.
- `discretisation_method`: How to discretise the overdamped Langevin diffusion. Default 'e-m' for Euler-Maruyama
"""

def __init__(
Expand All @@ -44,18 +45,54 @@ def __init__(
self.jacobian_chunk_size = jacobian_chunk_size
self.gradient_horizon = gradient_horizon
self.device = device
self.discretisation_method = discretisation_method
self._dim = self._verify_dim()
self._previous_log_density = None
self._previous_grad_theta_of_log_density = None
self._proposal = None

def _verify_dim(self):
"""
Checks the parameter dimension.
"""
return self.prior.sample((1,)).shape[-1]

@abstractmethod
def step(
self,
current_state: torch.Tensor,
data: torch.Tensor,
*args,
**kwargs
):
pass


class MALA(MCMCKernel):
"""
Class that generates a step in the chain of a Metropolis-Adjusted Langevin Algorithm run.
**Arguments**
- `prior`: The prior distribution. Must be differentiable in its argument.
- `w`: The weight hyperparameter in generalised posterior.
- `gradient_clipping_norm`: The norm to which the gradients are clipped.
- `forecast_loss`: The loss function used in the exponent of the generalised likelihood term. Maps from data and chain state to loss.
- `diff_mode`: The differentiation mode to use. Can be either 'reverse' or 'forward'.
- `jacobian_chunk_size`: The number of rows computed at a time for the model Jacobian. Set to None to compute the full Jacobian at once.
- `gradient_horizon`: The number of timesteps to use for the gradient horizon. Set 0 to use the full trajectory.
- `device`: The device to use for training.
- `discretisation_method`: How to discretise the overdamped Langevin diffusion. Default 'e-m' for Euler-Maruyama
"""

def __init__(
self,
*args,
discretisation_method: str = "e-m",
**kwargs
):
super().__init__(*args, **kwargs)
self.discretisation_method = discretisation_method
self._previous_log_density = None
self._previous_grad_theta_of_log_density = None
self._proposal = None

def _compute_log_density_and_grad(self, state, data):
_state = state.clone().detach()
_state.requires_grad = True
Expand Down Expand Up @@ -173,7 +210,7 @@ class MCMC:

def __init__(
self,
kernel,
kernel: MCMCKernel,
num_samples: int = 100_000,
progress_bar: bool = True,
progress_info: bool = True,
Expand All @@ -188,7 +225,33 @@ def __init__(
def reset(self):
self._samples = []

def run(self, initial_state, data, *args, seed=0, T=1, **kwargs):
def run(
self,
initial_state: torch.Tensor,
data: torch.Tensor,
*args,
seed: int = 0,
T: int = 1,
**kwargs
):

"""
Runs the MCMC chain.
**Arguments**
- `initial_state`: Starting location of the MCMC chain.
- `data`: A torch.Tensor containing the data against which the simulator is compared.
- `seed`: An integer specifying the initial random state of the RNG.
- `T`: An integer specifying the number of steps between updates of the progress info (if shown).
Additional arguments and keyword arguments can be passed, which will be passed to the kernel
.step() method.
"""

assert isinstance(initial_state, torch.Tensor), "Initial state of the MCMC chain must be a torch.Tensor"
assert isinstance(data, torch.Tensor), "The data must be passed as a torch.Tensor"

if seed is not None:
torch.manual_seed(seed)
self.reset()
Expand Down

0 comments on commit 3e53321

Please sign in to comment.