Commit f4c4e5e

adapted the code to work with normflows
arnauqb committed May 15, 2023
1 parent 6ac14d4 commit f4c4e5e
Showing 8 changed files with 102 additions and 100 deletions.
2 changes: 1 addition & 1 deletion birds/calibrator.py
@@ -102,7 +102,7 @@ def step(self):
) = compute_forecast_loss_and_jacobian(
loss_fn=self.forecast_loss,
model=self.model,
- parameter_generator=lambda x: self.posterior_estimator.rsample((x,)),
+ parameter_generator=lambda x: self.posterior_estimator.sample(x)[0],
observed_outputs=self.data,
n_samples=self.n_samples_per_epoch,
diff_mode=self.diff_mode,
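The one-line change above is the core of the normflows adaptation: `torch.distributions` objects draw reparameterised samples with `rsample((n,))`, whereas a normflows `NormalizingFlow` exposes `sample(n)` and returns both the samples and their log probabilities, so the calibrator keeps only element `[0]`. A small sketch of the two conventions (the flow built here is illustrative, not the estimator used in the project's notebooks):

```python
import torch
import normflows as nf

# torch.distributions convention: rsample((n,)) returns samples only
normal = torch.distributions.Normal(torch.zeros(1), torch.ones(1))
theta = normal.rsample((5,))                    # shape (5, 1)

# normflows convention: sample(n) returns (samples, log_prob)
base = nf.distributions.DiagGaussian(1)
flow = nf.NormalizingFlow(base, [nf.flows.Planar((1,))])
theta, log_q = flow.sample(5)                   # theta: (5, 1), log_q: (5,)

# hence the calibrator's parameter_generator only keeps the samples
parameter_generator = lambda n: flow.sample(n)[0]
```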
3 changes: 2 additions & 1 deletion birds/models/sir.py
@@ -40,9 +40,10 @@ def forward(self, params):
Runs the model forward
Arguments:
- params (torch.Tensor) : a tensor of shape (3,) containing the fraction of infected, beta, and gamma
+ params (torch.Tensor) : a tensor of shape (3,) containing the **log10** of the fraction of infected, beta, and gamma
"""
# Initialize the parameters
+ params = 10 ** params
initial_infected = params[0]
beta = params[1]
gamma = params[2]
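The model now receives its parameters in log10 space and undoes the transform with `10 ** params` before running the epidemic, so an unconstrained posterior (such as a normalizing flow over all of R^3) can never propose negative rates. A small illustration of the convention (values chosen for illustration only):

```python
import torch

# natural-scale parameters: 5% initially infected, beta = 0.05, gamma = 0.10
natural = torch.tensor([0.05, 0.05, 0.10])

# what SIR.forward now expects as input
log10_params = torch.log10(natural)    # tensor([-1.3010, -1.3010, -1.0000])

# the first line of forward inverts the transform
recovered = 10 ** log10_params         # back to [0.05, 0.05, 0.10] (up to rounding)
```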
4 changes: 1 addition & 3 deletions birds/regularisation.py
@@ -22,9 +22,7 @@ def compute_regularisation_loss(posterior_estimator, prior, n_samples):
tensor(0.5)
"""
# sample from the posterior
- z = posterior_estimator.sample((n_samples,))
- # compute the log probability of the samples under the posterior
- log_prob_posterior = posterior_estimator.log_prob(z)
+ z, log_prob_posterior = posterior_estimator.sample(n_samples)
# compute the log probability of the samples under the prior
log_prob_prior = prior.log_prob(z)
# compute the Monte Carlo estimate of the KL divergence
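With the normflows-style interface, a single `sample(n_samples)` call yields both the draws and their log density under the posterior, which is exactly what the Monte Carlo KL estimate needs. A self-contained sketch of that estimator (not the library function itself, which is only partially shown in this hunk):

```python
import torch

def mc_kl(posterior, prior, n_samples):
    """Monte Carlo estimate of KL(posterior || prior) = E_q[log q(z) - log p(z)]."""
    z, log_prob_posterior = posterior.sample(n_samples)
    log_prob_prior = prior.log_prob(z)
    return (log_prob_posterior - log_prob_prior).mean()

# For q = N(0, 1) and p = N(1, 1) this converges to 0.5, matching the
# tensor(0.5) in the docstring above and the test further down.
```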
112 changes: 56 additions & 56 deletions docs/examples/01-random_walk.ipynb

(Large diff not shown.)

20 changes: 20 additions & 0 deletions test/conftest.py
@@ -14,3 +14,23 @@ def set_random_seed(seed=999):
torch.cuda.manual_seed(seed)
return

class TrainableGaussian(torch.nn.Module):
    def __init__(self, mu = 0.0, sigma=1.0):
        super().__init__()
        self.mu = torch.nn.Parameter(mu * torch.ones(1))
        self.sigma = torch.nn.Parameter(sigma * torch.ones(1))

    def log_prob(self, x):
        sigma = torch.clip(self.sigma, min=1e-3)
        return torch.distributions.Normal(self.mu, sigma).log_prob(x)

    def sample(self, x):
        sigma = torch.clip(self.sigma, min=1e-3)
        dist = torch.distributions.Normal(self.mu, sigma)
        sample = dist.rsample((x,))
        return sample, dist.log_prob(sample)

@fixture(name="TrainableGaussian")
def make_trainable_gaussian():
    return TrainableGaussian
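The fixture mirrors the interface the rest of the code now assumes: `sample(n)` returns a `(samples, log_probs)` pair and `log_prob(x)` is available for the regularisation term, with `rsample` used internally so gradients reach `mu` and `sigma`. A brief usage sketch outside pytest (instantiating the class directly rather than through the fixture):

```python
import torch

g = TrainableGaussian(mu=0.5, sigma=0.1)
samples, log_probs = g.sample(10)        # both of shape (10, 1)
loss = samples.pow(2).mean()             # any differentiable function of the draws
loss.backward()
print(g.mu.grad, g.sigma.grad)           # non-None: gradients flow into the parameters
```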

12 changes: 6 additions & 6 deletions test/models/test_sir.py
@@ -68,9 +68,9 @@ def test__implementation_against_nnlib(self):
"""
N = 1000
- beta = 0.05
- gamma = 0.10
- fraction_infected = 0.05
+ beta = np.log10(0.05)
+ gamma = np.log10(0.10)
+ fraction_infected = np.log10(0.05)
n_timesteps = 10 # we do one more...
graph = networkx.erdos_renyi_graph(N, 0.01)
model = SIR(graph=graph, n_timesteps=n_timesteps)
@@ -97,9 +97,9 @@ def test__gradient_propagates(self):
Checks that the gradient propagates through the model.
"""
N = 1000
- beta = 0.05
- gamma = 0.10
- fraction_infected = 0.05
+ beta = -2.
+ gamma = -1.
+ fraction_infected = -2.
n_timesteps = 10
graph = networkx.erdos_renyi_graph(N, 0.01)
model = SIR(graph=graph, n_timesteps=n_timesteps)
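In the gradient test the log10 exponents are written out directly, so beta = -2. and gamma = -1. correspond to natural-scale rates of 0.01 and 0.1 once the model applies `10 ** params`. A quick check of the conversion:

```python
import numpy as np

log10_params = np.array([-2.0, -2.0, -1.0])   # fraction infected, beta, gamma
print(10 ** log10_params)                     # [0.01 0.01 0.1 ]
```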
34 changes: 9 additions & 25 deletions test/test_calibrator.py
@@ -1,31 +1,13 @@
import torch
import numpy as np
- import matplotlib.pyplot as plt

from birds.models.random_walk import RandomWalk
from birds.calibrator import Calibrator


- class TrainableGaussian(torch.nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.mu = torch.nn.Parameter(0.5 * torch.ones(1))
-         self.sigma = torch.nn.Parameter(0.1 * torch.ones(1))
-
-     def log_prob(self, x):
-         sigma = torch.clip(self.sigma, min=1e-3)
-         return torch.distributions.Normal(self.mu, sigma).log_prob(x)
-
-     def rsample(self, x):
-         sigma = torch.clip(self.sigma, min=1e-3)
-         return torch.distributions.Normal(self.mu, sigma).rsample(x)
-
-     def sample(self, x):
-         sigma = torch.clip(self.sigma, min=1e-3)
-         return torch.distributions.Normal(self.mu, sigma).sample(x)


class TestCalibrator:
- def test_random_walk(self):
+ def test_random_walk(self, TrainableGaussian):
"""
Tests inference in a random walk model.
"""
@@ -35,17 +17,19 @@ def test_random_walk(self):
for diff_mode in ("reverse", "forward"):
for true_p in true_ps:
data = rw(torch.tensor([true_p]))
- posterior_estimator = TrainableGaussian()
- optimizer = torch.optim.Adam(posterior_estimator.parameters(), lr=5e-2)
+ posterior_estimator = TrainableGaussian(0.5, 0.1)
+ optimizer = torch.optim.Adam(posterior_estimator.parameters(), lr=1e-2)
calib = Calibrator(
model=rw,
posterior_estimator=posterior_estimator,
prior=prior,
data=data,
optimizer=optimizer,
diff_mode=diff_mode,
w = 100.0,
progress_bar=False,
)
- _, best_model_state_dict = calib.run(1000)
+ _, best_model_state_dict = calib.run(100, max_epochs_without_improvement=100)
posterior_estimator.load_state_dict(best_model_state_dict)
- assert np.isclose(posterior_estimator.mu.item(), true_p, rtol=0.25)
- assert posterior_estimator.sigma.item() < 1e-2
+ # check correct result is within 2 sigma
+ assert np.abs(posterior_estimator.mu.item() - true_p) < 2 * posterior_estimator.sigma.item()
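The final assertions also change character: instead of requiring the posterior mean to be within 25% of the truth and the posterior standard deviation to collapse below 1e-2, the test now only asks that the true parameter lie within two posterior standard deviations of the mean. A toy comparison of the two criteria (numbers invented for the example):

```python
import numpy as np

true_p, mu, sigma = 2.0, 1.9, 0.08

# old criterion: mean close to the truth AND a very tight posterior
old_pass = np.isclose(mu, true_p, rtol=0.25) and sigma < 1e-2   # False: sigma too large

# new criterion: the truth sits within 2 sigma of the posterior mean
new_pass = abs(mu - true_p) < 2 * sigma                         # True: 0.1 < 0.16
```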
15 changes: 7 additions & 8 deletions test/test_regularisation.py
@@ -1,20 +1,19 @@
import torch
import numpy as np
from birds.regularisation import compute_regularisation_loss

class TestRegularisation:
- def test_regularisation(self):
+ def test_regularisation(self, TrainableGaussian):
n_samples = 100000
# define two normal distributions
- dist1 = torch.distributions.Normal(0, 1)
- dist2 = torch.distributions.Normal(0, 1)
+ dist1 = TrainableGaussian(0, 1)
+ dist2 = TrainableGaussian(0, 1)
# check that the KL divergence is 0
- assert np.isclose(compute_regularisation_loss(dist1, dist2, n_samples), 0.)
+ assert np.isclose(compute_regularisation_loss(dist1, dist2, n_samples).detach(), 0.)
# define two normal distributions with different means
- dist1 = torch.distributions.Normal(0, 1)
- dist2 = torch.distributions.Normal(1, 1)
+ dist1 = TrainableGaussian(0, 1)
+ dist2 = TrainableGaussian(1, 1)
# check that the KL divergence is the right result
- assert np.isclose(compute_regularisation_loss(dist1, dist2, n_samples), 0.5, rtol=1e-2)
+ assert np.isclose(compute_regularisation_loss(dist1, dist2, n_samples).detach(), 0.5, rtol=1e-2)
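The 0.5 target in the last assertion is the exact KL divergence between the two Gaussians: KL(N(0,1) || N(1,1)) = (mu1 - mu2)^2 / 2 = 0.5 when both standard deviations are 1. A quick check with PyTorch's analytic KL, independent of the code under test:

```python
import torch

posterior = torch.distributions.Normal(0.0, 1.0)
prior = torch.distributions.Normal(1.0, 1.0)

# closed form for Gaussians: log(s2/s1) + (s1**2 + (m1 - m2)**2) / (2 * s2**2) - 0.5
print(torch.distributions.kl_divergence(posterior, prior))   # tensor(0.5000)
```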


