mpi forecast
arnauqb committed May 12, 2023
1 parent 1d44bb5 commit 34ea5f6
Showing 6 changed files with 192 additions and 43 deletions.
98 changes: 75 additions & 23 deletions birds/forecast.py
@@ -3,6 +3,10 @@
import warnings


from birds.mpi_setup import mpi_size, mpi_rank, mpi_comm
from birds.jacfwd import jacfwd


def compute_loss(loss_fn, observed_outputs, simulated_outputs):
r"""Compute the loss between observed and simulated outputs.
@@ -31,21 +35,32 @@ def compute_loss(loss_fn, observed_outputs, simulated_outputs):
loss += loss_fn(observed_output, simulated_output)
is_nan = False
if is_nan:
-        return torch.nan
-    return loss
+        return torch.nan, torch.nan
+    return loss, loss  # need to return it twice for jac calculation


-def compute_forecast_loss(
-    loss_fn, model, parameter_generator, n_samples, observed_outputs
+def compute_forecast_loss_and_jacobian(
+    loss_fn,
+    model,
+    parameter_generator,
+    n_samples,
+    observed_outputs,
+    diff_mode="reverse",
+    jacobian_chunk_size=None,
+    device="cpu",
 ):
r"""Given a model and a parameter generator, compute the loss between the model outputs and the observed outputs.
r"""Computes the loss and the jacobian of the loss for each sample.
The jacobian is computed using the forward or reverse mode differentiation and the computation is parallelized
across the available devices.
Arguments:
-        loss_fn : callable
-        model : callable
-        parameter_generator : callable
-        n_samples : int
-        observed_outputs : list of torch.Tensor
+        loss_fn (callable) : loss function
+        model (callable) : PyTorch model
+        parameter_generator (callable) : parameter generator
+        n_samples (int) : number of samples
+        observed_outputs (list of torch.Tensor) : observed outputs
+        diff_mode (str) : differentiation mode, either "reverse" or "forward"
+        jacobian_chunk_size (int) : chunk size for the Jacobian computation (set to None for the maximum chunk size)
+        device (str) : device to use for the computation
Example:
>>> loss_fn = torch.nn.MSELoss()
>>> model = lambda x: [x**2]
@@ -54,19 +69,56 @@ def compute_forecast_loss(
    >>> parameters, loss, jacobians = compute_forecast_loss_and_jacobian(loss_fn, model, parameter_generator, 5, observed_outputs)
    >>> loss
    tensor(0.)
"""
-    n_samples_not_nan = 0
# Rank 0 samples from the flow
if mpi_rank == 0:
params_list = parameter_generator(n_samples)
params_list_comm = params_list.detach().cpu().numpy()
else:
params_list_comm = None
    # broadcast the parameters to all ranks (a no-op in serial mode, where mpi_comm is None)
    if mpi_comm is not None:
        params_list_comm = mpi_comm.bcast(params_list_comm, root=0)
# select forward or reverse jacobian calculator
if diff_mode == "reverse":
jacobian_diff_mode = torch.func.jacrev
else:
jacobian_diff_mode = lambda **kwargs: jacfwd(randomness="same", **kwargs)
loss_f = lambda x: compute_loss(
loss_fn=loss_fn, observed_outputs=observed_outputs, simulated_outputs=x
)
jacobian_calculator = jacobian_diff_mode(
func=loss_f,
argnums=0,
has_aux=True,
chunk_size=jacobian_chunk_size,
)
    # make each rank compute the loss for its parameters
    loss = 0
-    for _ in range(n_samples):
-        parameters = parameter_generator()
-        simulated_outputs = model(parameters)
-        loss_i = compute_loss(loss_fn, observed_outputs, simulated_outputs)
-        if np.isnan(loss_i):
-            continue
-        loss += loss_i
-        n_samples_not_nan += 1
-    if n_samples_not_nan == 0:
-        loss = torch.nan
-    else:
-        loss = loss / n_samples_not_nan
-    return loss, loss  # need to return it twice for the jacobian calculation
+    jacobians_per_rank = []
+    parameters_per_rank = []
+    for params in params_list_comm[mpi_rank::mpi_size]:
+        simulated_outputs = model(torch.tensor(params, device=device))
+        jacobian, loss_i = jacobian_calculator(simulated_outputs)
+        if np.isnan(loss_i):
+            continue
+        loss += loss_i
+        jacobians_per_rank.append(jacobian[0].cpu().numpy())
+        parameters_per_rank.append(params)

# gather the jacobians and parameters from all ranks
if mpi_size > 1:
jacobians_per_rank = mpi_comm.gather(jacobians_per_rank, root=0)
parameters_per_rank = mpi_comm.gather(parameters_per_rank, root=0)
if mpi_rank == 0:
jacobians = []
parameters = []
for jacobians_rank, parameters_rank in zip(
jacobians_per_rank, parameters_per_rank
):
jacobians.append(torch.tensor(jacobians_rank))
parameters.append(torch.tensor(parameters_rank))
        if mpi_comm is not None:
            loss = sum(mpi_comm.gather(loss, root=0))
if mpi_rank == 0:
loss = loss / len(parameters)
return parameters, loss, jacobians
else:
return None, None, None
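For context, here is a minimal sketch of how the new entry point can be driven in serial mode (without mpi4py installed, the fallback in birds/mpi_setup.py gives mpi_rank == 0 and mpi_size == 1). The toy model and generator mirror the updated test below; the printed value is what that test asserts.

import torch
from birds.forecast import compute_forecast_loss_and_jacobian

loss_fn = torch.nn.MSELoss()
model = lambda x: [x**2]
parameter_generator = lambda n: 2.0 * torch.ones((n, 2))
observed_outputs = [torch.tensor([4.0, 4.0])]

parameters, loss, jacobians = compute_forecast_loss_and_jacobian(
    loss_fn,
    model,
    parameter_generator,
    5,
    observed_outputs,
    diff_mode="forward",       # uses the chunked jacfwd below; "reverse" uses torch.func.jacrev
    jacobian_chunk_size=None,  # None lets vmap pick the largest chunk
)
print(loss)  # tensor(0.)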
79 changes: 79 additions & 0 deletions birds/jacfwd.py
@@ -0,0 +1,79 @@
"""
Modified version of torch.func.jacfwd that accepts a chunk size.
"""
from torch._functorch.eager_transforms import (
_slice_argnums,
_construct_standard_basis_for,
vmap,
tree_flatten,
Callable,
argnums_t,
wraps,
tree_unflatten,
_jvp_with_argnums,
safe_unflatten,
)


def jacfwd(
func: Callable,
argnums: argnums_t = 0,
has_aux: bool = False,
*,
randomness: str = "error",
chunk_size=None
):
@wraps(func)
def wrapper_fn(*args):
primals = args if argnums is None else _slice_argnums(args, argnums)
flat_primals, primals_spec = tree_flatten(primals)
flat_primals_numels = tuple(p.numel() for p in flat_primals)
flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels)
basis = tree_unflatten(flat_basis, primals_spec)

def push_jvp(basis):
output = _jvp_with_argnums(
func, args, basis, argnums=argnums, has_aux=has_aux
)
# output[0] is the output of `func(*args)`
if has_aux:
_, jvp_out, aux = output
return jvp_out, aux
_, jvp_out = output
return jvp_out

results = vmap(push_jvp, randomness=randomness, chunk_size=chunk_size)(basis)
if has_aux:
results, aux = results
# aux is in the standard basis format, e.g. NxN matrix
# We need to fetch the first element as original `func` output
flat_aux, aux_spec = tree_flatten(aux)
flat_aux = [value[0] for value in flat_aux]
aux = tree_unflatten(flat_aux, aux_spec)

jac_outs, spec = tree_flatten(results)
# Most probably below output check can never raise an error
# as jvp should test the output before
# assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)')

jac_outs_ins = tuple(
tuple(
safe_unflatten(jac_out_in, -1, primal.shape)
for primal, jac_out_in in zip(
flat_primals,
jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1),
)
)
for jac_out in jac_outs
)
jac_outs_ins = tuple(
tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins
)

if isinstance(argnums, int):
jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins)
if has_aux:
return tree_unflatten(jac_outs_ins, spec), aux
return tree_unflatten(jac_outs_ins, spec)

return wrapper_fn
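A quick sanity check of the added chunk_size keyword (a sketch, assuming a torch version whose private _functorch helpers match the imports above): the chunked result should agree with stock torch.func.jacfwd, with vmap pushing the basis vectors through in chunks instead of all at once.

import torch
from birds.jacfwd import jacfwd

def f(x):
    y = x ** 2 + 3 * x
    return y, y.sum()  # second value is passed through unchanged as the aux output

x = torch.arange(4.0)
jac, aux = jacfwd(f, has_aux=True, chunk_size=2)(x)
ref, _ = torch.func.jacfwd(f, has_aux=True)(x)
assert torch.allclose(jac, ref)  # diagonal matrix with entries 2*x + 3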
Empty file added birds/models/coin_flipping.py
Empty file.
9 changes: 9 additions & 0 deletions birds/mpi_setup.py
@@ -0,0 +1,9 @@
try:
from mpi4py import MPI
mpi_comm = MPI.COMM_WORLD
mpi_rank = mpi_comm.Get_rank()
mpi_size = mpi_comm.Get_size()
except ImportError:
    # fall back to serial execution when mpi4py is not available
    mpi_comm = None
mpi_rank = 0
mpi_size = 1
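These module-level defaults are what the mpi_rank::mpi_size slicing in forecast.py relies on; a small sketch of the resulting round-robin split:

from birds.mpi_setup import mpi_rank, mpi_size

samples = list(range(7))
my_samples = samples[mpi_rank::mpi_size]  # rank r takes samples r, r + size, r + 2*size, ...
print(f"rank {mpi_rank}/{mpi_size} -> {my_samples}")
# serial fallback: rank 0/1 -> [0, 1, 2, 3, 4, 5, 6]
# under e.g. mpirun -n 2: rank 0 -> [0, 2, 4, 6], rank 1 -> [1, 3, 5]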
36 changes: 16 additions & 20 deletions test/test_forecast.py
@@ -1,48 +1,44 @@
import pytest
import numpy as np
import torch

-from birds.forecast import compute_loss, compute_forecast_loss
+from birds.forecast import compute_loss, compute_forecast_loss_and_jacobian

class TestForecast:
def test__compute_loss(self):
loss_fn = torch.nn.MSELoss()
observed_outputs = [torch.tensor([1.0, 2, 3]), torch.tensor([4.0, 5, 6])]
simulated_outputs = [torch.tensor([1.0, 2, 3]), torch.tensor([4.0, 5, 6])]
-        assert compute_loss(loss_fn, observed_outputs, simulated_outputs) == 0
+        assert compute_loss(loss_fn, observed_outputs, simulated_outputs) == (0, 0)
simulated_outputs = [torch.tensor([1.0, 2, 3]), torch.tensor([4.0, 5, 7])]
assert torch.isclose(
-            compute_loss(loss_fn, observed_outputs, simulated_outputs),
+            compute_loss(loss_fn, observed_outputs, simulated_outputs)[0],
torch.tensor(0.3333),
rtol=1e-3,
)
simulated_outputs = [
torch.tensor([1.0, 2, 3]),
torch.tensor([4.0, 5, float("nan")]),
]
-        assert compute_loss(loss_fn, observed_outputs, simulated_outputs) == 0
+        assert compute_loss(loss_fn, observed_outputs, simulated_outputs) == (0, 0)
simulated_outputs = [
torch.tensor([1.0, 2, float("nan")]),
torch.tensor([4.0, 5, float("nan")]),
]
-        assert np.isnan(compute_loss(loss_fn, observed_outputs, simulated_outputs))
+        assert np.isnan(compute_loss(loss_fn, observed_outputs, simulated_outputs)[0])

def test__compute_forecast_loss(self):
loss_fn = torch.nn.MSELoss()
model = lambda x: [x ** 2]
-        parameter_generator = lambda: torch.tensor(2.0)
-        observed_outputs = [torch.tensor(4.0)]
-        assert compute_forecast_loss(
-            loss_fn, model, parameter_generator, 5, observed_outputs
-        ) == (0,0)
-        parameter_generator = lambda: torch.tensor(float("nan"))
-        assert np.isnan(
-            compute_forecast_loss(
-                loss_fn, model, parameter_generator, 5, observed_outputs
-            )[0]
-        )
-        parameter_generator = lambda: torch.tensor(2.0)
-        model = lambda x: [x ** 3]
-        assert compute_forecast_loss(
-            loss_fn, model, parameter_generator, 5, observed_outputs
-        ) == (16,16)
+        parameter_generator = lambda x: 2.0 * torch.ones((x, 2))
+        observed_outputs = [torch.tensor([4.0, 4.0])]
+        parameters, loss, jacobians = compute_forecast_loss_and_jacobian(
+            loss_fn, model, parameter_generator, 5, observed_outputs
+        )
+        assert len(parameters) == 5
+        for param in parameters:
+            assert torch.allclose(param, torch.tensor([2.0, 2.0]))
+        assert loss == torch.tensor(0.0)
+        assert len(jacobians) == 5
+        for jacob in jacobians:
+            assert torch.allclose(jacob, torch.tensor([0.0, 0.0]))
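Why the expected Jacobians are zero: the Jacobian is taken with respect to the simulated outputs (see forecast.py above), and 2.0**2 == 4.0 matches the observation exactly, so the MSE gradient 2 * (sim - obs) / n vanishes. A quick check:

import torch

sim = torch.tensor([4.0, 4.0], requires_grad=True)
obs = torch.tensor([4.0, 4.0])
torch.nn.MSELoss()(sim, obs).backward()
print(sim.grad)  # tensor([0., 0.])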

13 changes: 13 additions & 0 deletions test/test_jacfwd.py
@@ -0,0 +1,13 @@
from birds.jacfwd import jacfwd
import torch

class TestJacFwd:
def test__vs_pytorch(self):
func = lambda x: x ** 2 + 3*x
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
assert torch.allclose(jacfwd(func)(x), torch.func.jacfwd(func)(x))
func = lambda x, y: x**2 + y
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([2.0, 3.0, 4.0], requires_grad=True)
assert torch.allclose(jacfwd(func)(x, y), torch.func.jacfwd(func)(x, y))
