Skip to content

Commit

Permalink
Merge pull request #37 from arnauqb/mcmc_loss
Browse files Browse the repository at this point in the history
Changed name of loss from forecast_loss to loss
  • Loading branch information
arnauqb committed Jul 13, 2023
2 parents 3e53321 + 1824140 commit 6cf29ea
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions blackbirds/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class MCMCKernel(ABC):
- `prior`: The prior distribution. Must be differentiable in its argument.
- `w`: The weight hyperparameter in generalised posterior.
- `gradient_clipping_norm`: The norm to which the gradients are clipped.
- `forecast_loss`: The loss function used in the exponent of the generalised likelihood term. Maps from data and chain state to loss.
- `loss`: The loss function used in the exponent of the generalised likelihood term. Maps from data and chain state to loss.
- `diff_mode`: The differentiation mode to use. Can be either 'reverse' or 'forward'.
- `jacobian_chunk_size`: The number of rows computed at a time for the model Jacobian. Set to None to compute the full Jacobian at once.
- `gradient_horizon`: The number of timesteps to use for the gradient horizon. Set 0 to use the full trajectory.
Expand All @@ -28,7 +28,7 @@ class MCMCKernel(ABC):
def __init__(
self,
prior: torch.distributions.Distribution,
forecast_loss: Callable,
loss: Callable,
w: float = 1.0,
gradient_clipping_norm: float = np.inf,
diff_mode: str = "reverse",
Expand All @@ -40,7 +40,7 @@ def __init__(
self.prior = prior
self.w = w
self.gradient_clipping_norm = gradient_clipping_norm
self.forecast_loss = forecast_loss
self.loss = loss
self.diff_mode = diff_mode
self.jacobian_chunk_size = jacobian_chunk_size
self.gradient_horizon = gradient_horizon
Expand Down Expand Up @@ -73,7 +73,7 @@ class MALA(MCMCKernel):
- `prior`: The prior distribution. Must be differentiable in its argument.
- `w`: The weight hyperparameter in generalised posterior.
- `gradient_clipping_norm`: The norm to which the gradients are clipped.
- `forecast_loss`: The loss function used in the exponent of the generalised likelihood term. Maps from data and chain state to loss.
- `loss`: The loss function used in the exponent of the generalised likelihood term. Maps from data and chain state to loss.
- `diff_mode`: The differentiation mode to use. Can be either 'reverse' or 'forward'.
- `jacobian_chunk_size`: The number of rows computed at a time for the model Jacobian. Set to None to compute the full Jacobian at once.
- `gradient_horizon`: The number of timesteps to use for the gradient horizon. Set 0 to use the full trajectory.
Expand All @@ -96,7 +96,7 @@ def __init__(
def _compute_log_density_and_grad(self, state, data):
_state = state.clone().detach()
_state.requires_grad = True
ell = self.forecast_loss(_state, data)
ell = self.loss(_state, data)
log_prior_pdf = self.prior.log_prob(_state)
log_density = -ell + log_prior_pdf * self.w
log_density.backward()
Expand Down

0 comments on commit 6cf29ea

Please sign in to comment.