Skip to content

Commit

Permalink
Merge pull request #16 from arnauqb/fix_grad
Browse files Browse the repository at this point in the history
Fix grad
  • Loading branch information
arnauqb committed Jul 11, 2023
2 parents 7791d50 + 84322ae commit ed8fafe
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions blackbirds/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def step(self,
if self.discretisation_method == 'e-m':
if self._proposal is None:
# This would happen if the user hasn't initialised the chain themselves
gradient_term = torch.matmul(sC, self._previous_grad_theta_of_log_density[0])
gradient_term = torch.matmul(sC, self._previous_grad_theta_of_log_density)
mean = current_state + gradient_term
logger.debug("Total mean =", mean)
logger.debug("Gradient_term =", gradient_term)
Expand All @@ -121,7 +121,7 @@ def step(self,
# Compute reverse proposal logpdf
if self.discretisation_method == 'e-m':
try:
rev_proposal = torch.distributions.multivariate_normal.MultivariateNormal(new_state + torch.matmul(sC, grad_theta_of_new_log_density[0]),
rev_proposal = torch.distributions.multivariate_normal.MultivariateNormal(new_state + torch.matmul(sC, grad_theta_of_new_log_density),
covariance_matrix = 2 * sC)
except ValueError as e:
logger.debug(new_state, grad_theta_of_new_log_density)
Expand Down

0 comments on commit ed8fafe

Please sign in to comment.