Skip to content

Commit

Permalink
Added assertions to ensure initial_state and data passed to MCMC's .r…
Browse files Browse the repository at this point in the history
…un method are torch.Tensor
  • Loading branch information
joelnmdyer committed Jul 13, 2023
1 parent 4895722 commit 6c48d99
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions blackbirds/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ def run(
.step() method.
"""

assert isinstance(initial_state, torch.Tensor), "Initial state of the MCMC chain must be a torch.Tensor"
assert isinstance(data, torch.Tensor), "The data must be passed as a torch.Tensor"

if seed is not None:
torch.manual_seed(seed)
self.reset()
Expand Down

0 comments on commit 6c48d99

Please sign in to comment.