
Commit

forward mode tested
arnauqb committed May 12, 2023
1 parent 56f90cf commit 4124942
Showing 2 changed files with 27 additions and 15 deletions.
birds/infer.py (7 additions, 1 deletion)
@@ -23,6 +23,8 @@ def __init__(
         optimizer=None,
         n_samples_per_epoch=5,
         n_samples_regularisation=10_000,
+        diff_mode="reverse",
+        device="cpu",
         progress_bar=True,
     ):
         """
@@ -55,6 +57,8 @@ def __init__(
         self.n_samples_per_epoch = n_samples_per_epoch
         self.n_samples_regularisation = n_samples_regularisation
         self.progress_bar = progress_bar
+        self.diff_mode = diff_mode
+        self.device = device
 
     def _differentiate_loss(
         self, forecast_parameters, forecast_jacobians, regularisation_loss
@@ -78,7 +82,7 @@ def _differentiate_loss(
         # then we differentiate the parameters through the flows but also taking into account the jacobians of the simulator
         to_diff = torch.zeros(1)
         for i in range(len(forecast_jacobians)):
-            to_diff += torch.dot(forecast_jacobians[i], forecast_parameters[i,:])
+            to_diff += torch.dot(forecast_jacobians[i], forecast_parameters[i, :])
         to_diff.backward()
 
     def step(self):
@@ -96,6 +100,8 @@ def step(self):
             parameter_generator=lambda x: self.posterior_estimator.rsample((x,)),
             observed_outputs=self.data,
             n_samples=self.n_samples_per_epoch,
+            diff_mode=self.diff_mode,
+            device=self.device,
         )
         regularisation_loss = self.w * compute_regularisation_loss(
             posterior_estimator=self.posterior_estimator,
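Note on the _differentiate_loss hunk above: calling backward() on dot(J_i, theta_i) deposits J_i as the gradient of the i-th parameter sample, which is exactly the chain-rule term the simulator jacobian should pass back into the flow that produced theta_i. A minimal standalone sketch of that trick (the tensor values and names here are illustrative, not from the repository):

import torch

# backward() on dot(jacobian, theta) sets theta.grad = jacobian,
# so gradients keep flowing into whatever produced theta.
theta = torch.tensor([0.1, 0.2, 0.3], requires_grad=True)  # stand-in for a flow sample
jacobian = torch.tensor([1.0, -2.0, 0.5])  # stand-in for d(loss)/d(theta) from the simulator
surrogate = torch.dot(jacobian, theta)
surrogate.backward()
assert torch.allclose(theta.grad, jacobian)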
test/test_infer.py (20 additions, 14 deletions)
@@ -25,20 +25,26 @@ def sample(self, x):
 
 
 class TestInfer:
-    def test_infer(self):
+    def test_random_walk(self):
+        """
+        Tests inference in a random walk model.
+        """
         rw = RandomWalk(100)
         true_ps = [0.25, 0.5, 0.75]
         prior = torch.distributions.Normal(0.0, 1.0)
-        for true_p in true_ps:
-            data = rw(torch.tensor([true_p]))
-            posterior_estimator = TrainableGaussian()
-            optimizer = torch.optim.Adam(posterior_estimator.parameters(), lr=1e-2)
-            calib = Calibrator(
-                model=rw,
-                posterior_estimator=posterior_estimator,
-                prior=prior,
-                data=data,
-                optimizer=optimizer,
-            )
-            calib.run(1000)
-            assert np.isclose(calib.posterior_estimator.mu.item(), true_p, rtol=0.25)
+        for diff_mode in ("reverse", "forward"):
+            for true_p in true_ps:
+                data = rw(torch.tensor([true_p]))
+                posterior_estimator = TrainableGaussian()
+                optimizer = torch.optim.Adam(posterior_estimator.parameters(), lr=1e-2)
+                calib = Calibrator(
+                    model=rw,
+                    posterior_estimator=posterior_estimator,
+                    prior=prior,
+                    data=data,
+                    optimizer=optimizer,
+                    diff_mode=diff_mode,
+                )
+                calib.run(1000)
+                assert np.isclose(posterior_estimator.mu.item(), true_p, rtol=0.25)
+                assert posterior_estimator.sigma.item() < 1e-2

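The test now runs the calibration under both diff_mode settings. As a rough illustration of the difference (the actual dispatch lives in the loss-computation code, which is not part of this diff): reverse mode runs the model and then backpropagates a vector through it (a vector-Jacobian product), while forward mode pushes a tangent through alongside the primal (a Jacobian-vector product); for a scalar parameter the two agree. A minimal sketch, assuming PyTorch 2.x with torch.func available; toy_model is illustrative only:

import torch
from torch.func import jvp

def toy_model(p):
    # toy differentiable "simulator": output depends linearly on the parameter
    return p * torch.arange(1.0, 6.0)

# reverse mode: run forward, then backpropagate (vector-Jacobian product)
p = torch.tensor(0.5, requires_grad=True)
toy_model(p).sum().backward()
grad_reverse = p.grad

# forward mode: push a unit tangent alongside the primal (Jacobian-vector product)
_, tangent_out = jvp(toy_model, (torch.tensor(0.5),), (torch.tensor(1.0),))
grad_forward = tangent_out.sum()

assert torch.allclose(grad_reverse, grad_forward)  # both give d(sum)/dp = 15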