Commit d06bf10: fixed tests

arnauqb committed May 16, 2023
1 parent 4dd074d commit d06bf10
Showing 6 changed files with 120 additions and 98 deletions.
106 changes: 57 additions & 49 deletions birds/calibrator.py
@@ -98,7 +98,8 @@ def step(self):
"""
Performs one training step.
"""
self.optimizer.zero_grad()
if mpi_rank == 0:
self.optimizer.zero_grad()
(
forecast_parameters,
forecast_loss,
@@ -112,70 +113,77 @@ def step(self):
             diff_mode=self.diff_mode,
             device=self.device,
         )
-        regularisation_loss = self.w * compute_regularisation_loss(
-            posterior_estimator=self.posterior_estimator,
-            prior=self.prior,
-            n_samples=self.n_samples_regularisation,
-        )
-        self._differentiate_loss(
-            forecast_parameters, forecast_jacobians, regularisation_loss
-        )
-        # clip gradients
-        torch.nn.utils.clip_grad_norm_(
-            self.posterior_estimator.parameters(), self.gradient_clipping_norm
-        )
-        self.optimizer.step()
-        loss = forecast_loss + regularisation_loss
-        return loss, forecast_loss, regularisation_loss
+        if mpi_rank == 0:
+            regularisation_loss = self.w * compute_regularisation_loss(
+                posterior_estimator=self.posterior_estimator,
+                prior=self.prior,
+                n_samples=self.n_samples_regularisation,
+            )
+            self._differentiate_loss(
+                forecast_parameters, forecast_jacobians, regularisation_loss
+            )
+            # clip gradients
+            torch.nn.utils.clip_grad_norm_(
+                self.posterior_estimator.parameters(), self.gradient_clipping_norm
+            )
+            self.optimizer.step()
+            loss = forecast_loss + regularisation_loss
+            return loss, forecast_loss, regularisation_loss
+        return None, None, None

     def run(self, n_epochs, max_epochs_without_improvement=20):
         """
         Runs the calibrator for {n_epochs} epochs. Stops if the loss does not improve for {max_epochs_without_improvement} epochs.
         Arguments:
-            n_epochs (int): The number of epochs to run the calibrator for.
+            n_epochs (int | np.inf): The number of epochs to run the calibrator for.
             max_epochs_without_improvement (int): The number of epochs without improvement after which the calibrator stops.
         """
         self.best_loss = torch.tensor(np.inf)
         self.best_model_state_dict = None
-        self.writer = SummaryWriter(log_dir=self.tensorboard_log_dir)
+        if mpi_rank == 0:
+            self.writer = SummaryWriter(log_dir=self.tensorboard_log_dir)
         num_epochs_without_improvement = 0
         iterator = range(n_epochs)
         if self.progress_bar and mpi_rank == 0:
             iterator = tqdm(iterator)
         self.losses_hist = defaultdict(list)
         for epoch in iterator:
             loss, forecast_loss, regularisation_loss = self.step()
self.losses_hist["total"].append(loss.item())
self.losses_hist["forecast"].append(forecast_loss.item())
self.losses_hist["regularisation"].append(regularisation_loss.item())
self.writer.add_scalar("Loss/total", loss, epoch)
self.writer.add_scalar("Loss/forecast", forecast_loss, epoch)
self.writer.add_scalar("Loss/regularisation", regularisation_loss, epoch)
if loss < self.best_loss:
self.best_loss = loss
self.best_model_state_dict = deepcopy(
self.posterior_estimator.state_dict()
)
num_epochs_without_improvement = 0
else:
num_epochs_without_improvement += 1
if self.progress_bar:
iterator.set_postfix(
{
"Forecast": forecast_loss.item(),
"Reg.": regularisation_loss.item(),
"total": loss.item(),
"best loss": self.best_loss.item(),
"epochs since improv.": num_epochs_without_improvement,
}
if mpi_rank == 0:
self.losses_hist["total"].append(loss.item())
self.losses_hist["forecast"].append(forecast_loss.item())
self.losses_hist["regularisation"].append(regularisation_loss.item())
self.writer.add_scalar("Loss/total", loss, epoch)
self.writer.add_scalar("Loss/forecast", forecast_loss, epoch)
self.writer.add_scalar(
"Loss/regularisation", regularisation_loss, epoch
)
if num_epochs_without_improvement >= max_epochs_without_improvement:
logger.info(
"Stopping early because the loss did not improve for {} epochs.".format(
max_epochs_without_improvement
if loss < self.best_loss:
self.best_loss = loss
self.best_model_state_dict = deepcopy(
self.posterior_estimator.state_dict()
)
)
break
self.writer.flush()
self.writer.close()
num_epochs_without_improvement = 0
else:
num_epochs_without_improvement += 1
if self.progress_bar:
iterator.set_postfix(
{
"Forecast": forecast_loss.item(),
"Reg.": regularisation_loss.item(),
"total": loss.item(),
"best loss": self.best_loss.item(),
"epochs since improv.": num_epochs_without_improvement,
}
)
if num_epochs_without_improvement >= max_epochs_without_improvement:
logger.info(
"Stopping early because the loss did not improve for {} epochs.".format(
max_epochs_without_improvement
)
)
break
if mpi_rank == 0:
self.writer.flush()
self.writer.close()
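Note on the recurring `if mpi_rank == 0:` guard, which is the heart of this change: under MPI every rank executes the same script, so side effects that must happen exactly once (the optimizer step, TensorBoard writing, bookkeeping) are restricted to the root rank, while other ranks only contribute simulations and return `None` sentinels from `step()`. A minimal sketch of the pattern, assuming `mpi4py`; how birds actually initialises `mpi_rank` and `mpi_comm` is not shown in this diff:

import random

# Sketch only: the rank-0 guard pattern used throughout this commit.
from mpi4py import MPI

comm = MPI.COMM_WORLD
mpi_rank = comm.Get_rank()

local_loss = random.random()  # stand-in for a per-rank forecast loss

# gather() returns a list of contributions on the root rank, None elsewhere.
losses = comm.gather(local_loss, root=0)
if mpi_rank == 0:
    print("total loss:", sum(losses))  # only rank 0 aggregates and logs
# Non-root ranks skip aggregation, mirroring step()'s `return None, None, None`.

Run with, e.g., `mpirun -n 4 python sketch.py`.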
8 changes: 4 additions & 4 deletions birds/forecast.py
@@ -115,12 +115,12 @@ def loss_f(params):
     else:
         jacobians_per_rank = [jacobians_per_rank]
         indices_per_rank = [indices_per_rank]
-    if mpi_comm is not None:
-        losses = mpi_comm.gather(loss, root=0)
-        if mpi_rank == 0:
-            loss = sum(losses)
     if mpi_rank == 0:
         jacobians = []
+        jacobians = list(chain(*jacobians_per_rank))
+    if mpi_comm is not None:
+        loss = sum(mpi_comm.gather(loss, root=0))
+    if mpi_rank == 0:
         indices = list(chain(*indices_per_rank))
         parameters = params_list[indices]
         loss = loss / len(parameters)
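The `chain(*jacobians_per_rank)` idiom above flattens one list of results per rank into a single list on the root process. A small self-contained illustration, with made-up per-rank lists standing in for gathered results:

from itertools import chain

# Hypothetical gathered results: one list per MPI rank.
jacobians_per_rank = [["J0", "J1"], ["J2"], ["J3", "J4"]]
indices_per_rank = [[0, 1], [2], [3, 4]]

# chain(*lists) concatenates the per-rank lists into one flat sequence.
jacobians = list(chain(*jacobians_per_rank))
indices = list(chain(*indices_per_rank))
print(jacobians)  # ['J0', 'J1', 'J2', 'J3', 'J4']
print(indices)    # [0, 1, 2, 3, 4]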
36 changes: 25 additions & 11 deletions birds/models/sir.py
@@ -3,11 +3,14 @@

 from birds.utils import soft_minimum, soft_maximum
 
+
 class SIRMessagePassing(torch_geometric.nn.conv.MessagePassing):
     def forward(self, edge_index, infected, susceptible):
         return self.propagate(edge_index, x=infected, y=susceptible)
 
     def message(self, x_j, y_i):
-        return x_j * y_i
+        return x_j * y_i
 
+
 class SIR(torch.nn.Module):
     def __init__(self, graph, n_timesteps):
@@ -22,20 +25,19 @@ def __init__(self, graph, n_timesteps):
         self.n_timesteps = n_timesteps
         # convert graph from networkx to pytorch geometric
         self.graph = torch_geometric.utils.convert.from_networkx(graph)
-        self.mp = SIRMessagePassing(aggr='add', node_dim=-1)
+        self.mp = SIRMessagePassing(aggr="add", node_dim=-1)
 
     def sample_bernoulli_gs(self, probs, tau=0.1):
         """
         Samples from a Bernoulli distribution in a differentiable way using Gumbel-Softmax.
         Arguments:
             probs (torch.Tensor) : a tensor of shape (n,) containing the probabilities of success for each trial
             tau (float) : the temperature of the Gumbel-Softmax distribution
         """
         logits = torch.vstack((probs, 1 - probs)).T.log()
         gs_samples = torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True)
-        return gs_samples[:,0]
-
+        return gs_samples[:, 0]
 
     def forward(self, params):
         """
@@ -45,8 +47,8 @@ def forward(self, params):
         params (torch.Tensor) : a tensor of shape (3,) containing the **log10** of the fraction of infected, beta, and gamma
         """
         # Initialize the parameters
-        params = soft_minimum(params, torch.tensor(0.0), 1)
-        params = 10 ** params
+        params = soft_minimum(params, torch.tensor(0.0), 2)
+        params = 10**params
 
         initial_infected = params[0]
         beta = params[1]
@@ -62,7 +64,6 @@ def forward(self, params):
             infected += new_infected
             susceptible -= new_infected
 
-
         infected_hist = infected.sum().reshape((1,))
         recovered_hist = torch.zeros((1,))
 
@@ -84,8 +85,21 @@ def forward(self, params):
             infected = infected + new_infected - new_recovered
             susceptible -= new_infected
             recovered += new_recovered
-            infected_hist = torch.hstack((infected_hist, infected.sum().reshape(1,)))
-            recovered_hist = torch.hstack((recovered_hist, recovered.sum().reshape(1,)))
+            infected_hist = torch.hstack(
+                (
+                    infected_hist,
+                    infected.sum().reshape(
+                        1,
+                    ),
+                )
+            )
+            recovered_hist = torch.hstack(
+                (
+                    recovered_hist,
+                    recovered.sum().reshape(
+                        1,
+                    ),
+                )
+            )
 
         return infected_hist, recovered_hist
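`sample_bernoulli_gs` is what keeps this agent-based model differentiable: a hard Bernoulli draw has no useful gradient, so infections and recoveries are sampled with the Gumbel-Softmax trick, where `hard=True` returns one-hot samples in the forward pass but routes straight-through (soft) gradients in the backward pass. A standalone sketch of the same trick outside the class, with illustrative values:

import torch

def sample_bernoulli_gs(probs, tau=0.1):
    # Two-class logits per trial: log p and log(1 - p).
    logits = torch.vstack((probs, 1 - probs)).T.log()
    # hard=True: one-hot samples forward, straight-through gradients backward.
    gs_samples = torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True)
    return gs_samples[:, 0]

probs = torch.full((1000,), 0.3, requires_grad=True)
samples = sample_bernoulli_gs(probs)
print(samples.mean().item())   # fluctuates around 0.3
samples.sum().backward()
print(probs.grad is not None)  # True: gradients flow through the sampler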

62 changes: 31 additions & 31 deletions docs/examples/02-SIR.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion test/models/test_sir.py
@@ -68,9 +68,9 @@ def test__implementation_against_nnlib(self):
"""
N = 1000
fraction_infected = np.log10(0.05)
beta = np.log10(0.05)
gamma = np.log10(0.10)
fraction_infected = np.log10(0.05)
n_timesteps = 10 # we do one more...
graph = networkx.erdos_renyi_graph(N, 0.01)
model = SIR(graph=graph, n_timesteps=n_timesteps)
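As the test shows, the SIR model takes its parameters as log10 values. In `SIR.forward` they are clamped with `soft_minimum(params, torch.tensor(0.0), 2)`, and this commit sharpens the smoothing constant from 1 to 2, before `10**params` recovers natural units. A sketch under the assumption that `soft_minimum` in `birds.utils` is a logsumexp-style smooth minimum; only its call signature appears in this diff:

import torch

def soft_minimum(a, b, k):
    # Hypothetical logsumexp-based smooth minimum; the real birds.utils
    # implementation may differ in detail.
    return -torch.logsumexp(torch.stack((-k * a, -k * b.expand_as(a))), dim=0) / k

# fraction_infected, beta, gamma passed as log10 values, softly capped at 0
# (i.e. at 1.0 in natural units), then exponentiated back.
params = torch.log10(torch.tensor([0.05, 0.05, 0.10]))
params = soft_minimum(params, torch.tensor(0.0), 2)
params = 10**params
print(params)  # close to [0.05, 0.05, 0.10]; the soft cap pulls values down slightly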
4 changes: 2 additions & 2 deletions test/test_calibrator.py
@@ -28,10 +28,10 @@ def test_random_walk(self, TrainableGaussian):
             w=100.0,
             progress_bar=False,
         )
-        _, best_model_state_dict = calib.run(
+        calib.run(
             100, max_epochs_without_improvement=100
         )
-        posterior_estimator.load_state_dict(best_model_state_dict)
+        posterior_estimator.load_state_dict(calib.best_model_state_dict)
         # check correct result is within 2 sigma
         assert (
             np.abs(posterior_estimator.mu.item() - true_p)
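The test change reflects an API change: `Calibrator.run` no longer returns the best state dict; callers now read it from `calib.best_model_state_dict` afterwards. The underlying keep-the-best-snapshot pattern, which `run` implements internally as shown above, looks like this in isolation (toy model and losses):

import numpy as np
import torch
from copy import deepcopy

model = torch.nn.Linear(2, 1)
best_loss, best_state = torch.tensor(np.inf), None
for epoch in range(5):
    loss = torch.rand(())  # stand-in for the calibrator's per-epoch loss
    if loss < best_loss:
        best_loss = loss
        # deepcopy so that later epochs cannot mutate the snapshot.
        best_state = deepcopy(model.state_dict())
model.load_state_dict(best_state)  # restore the best epoch, as the test now does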
