Skip to content

Commit

Permalink
laod best model in test
Browse files Browse the repository at this point in the history
  • Loading branch information
arnauqb committed May 12, 2023
1 parent 4124942 commit ebc5846
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
10 changes: 10 additions & 0 deletions birds/infer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import numpy as np
from copy import deepcopy
import torch
from tqdm import tqdm
import logging
from collections import defaultdict

from birds.mpi_setup import mpi_rank
from birds.forecast import compute_forecast_loss_and_jacobian
Expand Down Expand Up @@ -124,14 +126,20 @@ def run(self, n_epochs, max_epochs_without_improvement=20):
max_epochs_without_improvement (int): The number of epochs without improvement after which the calibrator stops.
"""
best_loss = np.inf
best_model_state_dict = None
num_epochs_without_improvement = 0
iterator = range(n_epochs)
if self.progress_bar and mpi_rank == 0:
iterator = tqdm(iterator)
losses_hist = defaultdict(list)
for _ in tqdm(range(n_epochs)):
loss, forecast_loss, regularisation_loss = self.step()
losses_hist["total"].append(loss.item())
losses_hist["forecast"].append(forecast_loss.item())
losses_hist["regularisation"].append(regularisation_loss.item())
if loss < best_loss:
best_loss = loss
best_model_state_dict = deepcopy(self.posterior_estimator.state_dict())
num_epochs_without_improvement = 0
else:
num_epochs_without_improvement += 1
Expand All @@ -152,3 +160,5 @@ def run(self, n_epochs, max_epochs_without_improvement=20):
)
)
break

return losses_hist, best_model_state_dict
5 changes: 3 additions & 2 deletions test/test_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def test_random_walk(self):
for true_p in true_ps:
data = rw(torch.tensor([true_p]))
posterior_estimator = TrainableGaussian()
optimizer = torch.optim.Adam(posterior_estimator.parameters(), lr=1e-2)
optimizer = torch.optim.Adam(posterior_estimator.parameters(), lr=5e-2)
calib = Calibrator(
model=rw,
posterior_estimator=posterior_estimator,
Expand All @@ -45,6 +45,7 @@ def test_random_walk(self):
optimizer=optimizer,
diff_mode=diff_mode,
)
calib.run(1000)
_, best_model_state_dict = calib.run(1000)
posterior_estimator.load_state_dict(best_model_state_dict)
assert np.isclose(posterior_estimator.mu.item(), true_p, rtol=0.25)
assert posterior_estimator.sigma.item() < 1e-2

0 comments on commit ebc5846

Please sign in to comment.