From ebc5846a795de07135a123ec9fdae5920b1cb414 Mon Sep 17 00:00:00 2001 From: Arnau Quera-Bofarull Date: Fri, 12 May 2023 15:51:21 +0100 Subject: [PATCH] laod best model in test --- birds/infer.py | 10 ++++++++++ test/test_infer.py | 5 +++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/birds/infer.py b/birds/infer.py index 13e3701..f670904 100644 --- a/birds/infer.py +++ b/birds/infer.py @@ -1,7 +1,9 @@ import numpy as np +from copy import deepcopy import torch from tqdm import tqdm import logging +from collections import defaultdict from birds.mpi_setup import mpi_rank from birds.forecast import compute_forecast_loss_and_jacobian @@ -124,14 +126,20 @@ def run(self, n_epochs, max_epochs_without_improvement=20): max_epochs_without_improvement (int): The number of epochs without improvement after which the calibrator stops. """ best_loss = np.inf + best_model_state_dict = None num_epochs_without_improvement = 0 iterator = range(n_epochs) if self.progress_bar and mpi_rank == 0: iterator = tqdm(iterator) + losses_hist = defaultdict(list) for _ in tqdm(range(n_epochs)): loss, forecast_loss, regularisation_loss = self.step() + losses_hist["total"].append(loss.item()) + losses_hist["forecast"].append(forecast_loss.item()) + losses_hist["regularisation"].append(regularisation_loss.item()) if loss < best_loss: best_loss = loss + best_model_state_dict = deepcopy(self.posterior_estimator.state_dict()) num_epochs_without_improvement = 0 else: num_epochs_without_improvement += 1 @@ -152,3 +160,5 @@ def run(self, n_epochs, max_epochs_without_improvement=20): ) ) break + + return losses_hist, best_model_state_dict diff --git a/test/test_infer.py b/test/test_infer.py index eaf645a..43fc591 100644 --- a/test/test_infer.py +++ b/test/test_infer.py @@ -36,7 +36,7 @@ def test_random_walk(self): for true_p in true_ps: data = rw(torch.tensor([true_p])) posterior_estimator = TrainableGaussian() - optimizer = torch.optim.Adam(posterior_estimator.parameters(), lr=1e-2) + optimizer = torch.optim.Adam(posterior_estimator.parameters(), lr=5e-2) calib = Calibrator( model=rw, posterior_estimator=posterior_estimator, @@ -45,6 +45,7 @@ def test_random_walk(self): optimizer=optimizer, diff_mode=diff_mode, ) - calib.run(1000) + _, best_model_state_dict = calib.run(1000) + posterior_estimator.load_state_dict(best_model_state_dict) assert np.isclose(posterior_estimator.mu.item(), true_p, rtol=0.25) assert posterior_estimator.sigma.item() < 1e-2