diff --git a/birds/forecast.py b/birds/forecast.py
index 711203c..352a2b0 100644
--- a/birds/forecast.py
+++ b/birds/forecast.py
@@ -3,6 +3,10 @@
 import warnings
 
+from birds.mpi_setup import mpi_size, mpi_rank, mpi_comm
+from birds.jacfwd import jacfwd
+
+
 def compute_loss(loss_fn, observed_outputs, simulated_outputs):
     r"""Compute the loss between observed and simulated outputs.
@@ -31,21 +35,32 @@ def compute_loss(loss_fn, observed_outputs, simulated_outputs):
         loss += loss_fn(observed_output, simulated_output)
         is_nan = False
     if is_nan:
-        return torch.nan
-    return loss
+        return torch.nan, torch.nan
+    return loss, loss  # returned twice so the jacobian calculators can treat it as the aux output
 
 
-def compute_forecast_loss(
-    loss_fn, model, parameter_generator, n_samples, observed_outputs
+def compute_forecast_loss_and_jacobian(
+    loss_fn,
+    model,
+    parameter_generator,
+    n_samples,
+    observed_outputs,
+    diff_mode="reverse",
+    jacobian_chunk_size=None,
+    device="cpu",
 ):
-    r"""Given a model and a parameter generator, compute the loss between the model outputs and the observed outputs.
-
+    r"""Computes the loss and the jacobian of the loss for each sample.
+
+    The jacobian is computed with forward- or reverse-mode differentiation, and the
+    computation is parallelised across the available MPI ranks.
+
     Arguments:
-        loss_fn : callable
-        model : callable
-        parameter_generator : callable
-        n_samples : int
-        observed_outputs : list of torch.Tensor
+        loss_fn (callable) : loss function
+        model (callable) : differentiable PyTorch model
+        parameter_generator (callable) : takes a number of samples and returns a tensor of parameters
+        n_samples (int) : number of samples
+        observed_outputs (list of torch.Tensor) : observed outputs
+        diff_mode (str) : differentiation mode, either "reverse" or "forward"
+        jacobian_chunk_size (int) : chunk size for the Jacobian computation (None uses the maximum chunk size)
+        device (str) : device to use for the computation
 
     Example:
         >>> loss_fn = torch.nn.MSELoss()
         >>> model = lambda x: [x**2]
@@ -54,19 +69,56 @@ def compute_forecast_loss(
         >>> compute_forecast_loss(loss_fn, model, parameter_generator, 5, observed_outputs)
         tensor(0.)
""" - n_samples_not_nan = 0 + # Rank 0 samples from the flow + if mpi_rank == 0: + params_list = parameter_generator(n_samples) + params_list_comm = params_list.detach().cpu().numpy() + else: + params_list_comm = None + # scatter the parameters to all ranks + params_list_comm = mpi_comm.bcast(params_list_comm, root=0) + # select forward or reverse jacobian calculator + if diff_mode == "reverse": + jacobian_diff_mode = torch.func.jacrev + else: + jacobian_diff_mode = lambda **kwargs: jacfwd(randomness="same", **kwargs) + loss_f = lambda x: compute_loss( + loss_fn=loss_fn, observed_outputs=observed_outputs, simulated_outputs=x + ) + jacobian_calculator = jacobian_diff_mode( + func=loss_f, + argnums=0, + has_aux=True, + chunk_size=jacobian_chunk_size, + ) + # make each rank compute the loss for its parameters loss = 0 - for _ in range(n_samples): - parameters = parameter_generator() - simulated_outputs = model(parameters) - loss_i = compute_loss(loss_fn, observed_outputs, simulated_outputs) - if np.isnan(loss_i): + jacobians_per_rank = [] + parameters_per_rank = [] + for params in params_list_comm[mpi_rank::mpi_size]: + simulated_outputs = model(torch.tensor(params, device=device)) + jacobian, loss_i = jacobian_calculator(simulated_outputs) + if np.isnan(loss): continue loss += loss_i - n_samples_not_nan += 1 - if n_samples_not_nan == 0: - loss = torch.nan - else: - loss = loss / n_samples_not_nan - return loss, loss # need to return it twice for the jacobian calculation + jacobians_per_rank.append(jacobian[0].cpu().numpy()) + parameters_per_rank.append(params) + # gather the jacobians and parameters from all ranks + if mpi_size > 1: + jacobians_per_rank = mpi_comm.gather(jacobians_per_rank, root=0) + parameters_per_rank = mpi_comm.gather(parameters_per_rank, root=0) + if mpi_rank == 0: + jacobians = [] + parameters = [] + for jacobians_rank, parameters_rank in zip( + jacobians_per_rank, parameters_per_rank + ): + jacobians.append(torch.tensor(jacobians_rank)) + parameters.append(torch.tensor(parameters_rank)) + loss = sum(mpi_comm.gather(loss, root=0)) + if mpi_rank == 0: + loss = loss / len(parameters) + return parameters, loss, jacobians + else: + return None, None, None diff --git a/birds/jacfwd.py b/birds/jacfwd.py new file mode 100644 index 0000000..60671ac --- /dev/null +++ b/birds/jacfwd.py @@ -0,0 +1,79 @@ +""" +Modified version of torch.func.jacfwd that accepts chunk size. 
+""" +from torch._functorch.eager_transforms import ( + _slice_argnums, + _construct_standard_basis_for, + vmap, + tree_flatten, + Callable, + argnums_t, + wraps, + tree_unflatten, + _jvp_with_argnums, + safe_unflatten, +) + + +def jacfwd( + func: Callable, + argnums: argnums_t = 0, + has_aux: bool = False, + *, + randomness: str = "error", + chunk_size=None +): + @wraps(func) + def wrapper_fn(*args): + primals = args if argnums is None else _slice_argnums(args, argnums) + flat_primals, primals_spec = tree_flatten(primals) + flat_primals_numels = tuple(p.numel() for p in flat_primals) + flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels) + basis = tree_unflatten(flat_basis, primals_spec) + + def push_jvp(basis): + output = _jvp_with_argnums( + func, args, basis, argnums=argnums, has_aux=has_aux + ) + # output[0] is the output of `func(*args)` + if has_aux: + _, jvp_out, aux = output + return jvp_out, aux + _, jvp_out = output + return jvp_out + + results = vmap(push_jvp, randomness=randomness, chunk_size=chunk_size)(basis) + if has_aux: + results, aux = results + # aux is in the standard basis format, e.g. NxN matrix + # We need to fetch the first element as original `func` output + flat_aux, aux_spec = tree_flatten(aux) + flat_aux = [value[0] for value in flat_aux] + aux = tree_unflatten(flat_aux, aux_spec) + + jac_outs, spec = tree_flatten(results) + # Most probably below output check can never raise an error + # as jvp should test the output before + # assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)') + + jac_outs_ins = tuple( + tuple( + safe_unflatten(jac_out_in, -1, primal.shape) + for primal, jac_out_in in zip( + flat_primals, + jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1), + ) + ) + for jac_out in jac_outs + ) + jac_outs_ins = tuple( + tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins + ) + + if isinstance(argnums, int): + jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins) + if has_aux: + return tree_unflatten(jac_outs_ins, spec), aux + return tree_unflatten(jac_outs_ins, spec) + + return wrapper_fn diff --git a/birds/models/coin_flipping.py b/birds/models/coin_flipping.py new file mode 100644 index 0000000..e69de29 diff --git a/birds/mpi_setup.py b/birds/mpi_setup.py new file mode 100644 index 0000000..eaf31c9 --- /dev/null +++ b/birds/mpi_setup.py @@ -0,0 +1,9 @@ +try: + from mpi4py import MPI + mpi_comm = MPI.COMM_WORLD + mpi_rank = mpi_comm.Get_rank() + mpi_size = mpi_comm.Get_size() +except: + mpi_comm = None + mpi_rank = 0 + mpi_size = 1 diff --git a/test/test_forecast.py b/test/test_forecast.py index 2b1a7a8..a94ad96 100644 --- a/test/test_forecast.py +++ b/test/test_forecast.py @@ -1,18 +1,17 @@ -import pytest import numpy as np import torch -from birds.forecast import compute_loss, compute_forecast_loss +from birds.forecast import compute_loss, compute_forecast_loss_and_jacobian class TestForecast: def test__compute_loss(self): loss_fn = torch.nn.MSELoss() observed_outputs = [torch.tensor([1.0, 2, 3]), torch.tensor([4.0, 5, 6])] simulated_outputs = [torch.tensor([1.0, 2, 3]), torch.tensor([4.0, 5, 6])] - assert compute_loss(loss_fn, observed_outputs, simulated_outputs) == 0 + assert compute_loss(loss_fn, observed_outputs, simulated_outputs) == (0, 0) simulated_outputs = [torch.tensor([1.0, 2, 3]), torch.tensor([4.0, 5, 7])] assert torch.isclose( - compute_loss(loss_fn, observed_outputs, simulated_outputs), + compute_loss(loss_fn, observed_outputs, simulated_outputs)[0], torch.tensor(0.3333), 
             rtol=1e-3,
         )
@@ -20,29 +19,26 @@ def test__compute_loss(self):
             torch.tensor([1.0, 2, 3]),
             torch.tensor([4.0, 5, float("nan")]),
         ]
-        assert compute_loss(loss_fn, observed_outputs, simulated_outputs) == 0
+        assert compute_loss(loss_fn, observed_outputs, simulated_outputs) == (0, 0)
         simulated_outputs = [
             torch.tensor([1.0, 2, float("nan")]),
             torch.tensor([4.0, 5, float("nan")]),
         ]
-        assert np.isnan(compute_loss(loss_fn, observed_outputs, simulated_outputs))
+        assert np.isnan(compute_loss(loss_fn, observed_outputs, simulated_outputs)[0])
 
     def test__compute_forecast_loss(self):
         loss_fn = torch.nn.MSELoss()
         model = lambda x: [x ** 2]
-        parameter_generator = lambda: torch.tensor(2.0)
-        observed_outputs = [torch.tensor(4.0)]
-        assert compute_forecast_loss(
+        parameter_generator = lambda x: 2.0 * torch.ones((x, 2))
+        observed_outputs = [torch.tensor([4.0, 4.0])]
+        parameters, loss, jacobians = compute_forecast_loss_and_jacobian(
             loss_fn, model, parameter_generator, 5, observed_outputs
-        ) == (0,0)
-        parameter_generator = lambda: torch.tensor(float("nan"))
-        assert np.isnan(
-            compute_forecast_loss(
-                loss_fn, model, parameter_generator, 5, observed_outputs
-            )[0]
         )
-        parameter_generator = lambda: torch.tensor(2.0)
-        model = lambda x: [x ** 3]
-        assert compute_forecast_loss(
-            loss_fn, model, parameter_generator, 5, observed_outputs
-        ) == (16,16)
+        assert len(parameters) == 5
+        for param in parameters:
+            assert torch.allclose(param, torch.tensor([2.0, 2.0]))
+        assert loss == torch.tensor(0.0)
+        assert len(jacobians) == 5
+        for jacob in jacobians:
+            assert torch.allclose(jacob, torch.tensor([0.0, 0.0]))
diff --git a/test/test_jacfwd.py b/test/test_jacfwd.py
new file mode 100644
index 0000000..decc8d5
--- /dev/null
+++ b/test/test_jacfwd.py
@@ -0,0 +1,13 @@
+from birds.jacfwd import jacfwd
+import torch
+
+
+class TestJacFwd:
+    def test__vs_pytorch(self):
+        func = lambda x: x ** 2 + 3 * x
+        x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
+        assert torch.allclose(jacfwd(func)(x), torch.func.jacfwd(func)(x))
+        func = lambda x, y: x ** 2 + y
+        x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
+        y = torch.tensor([2.0, 3.0, 4.0], requires_grad=True)
+        assert torch.allclose(jacfwd(func)(x, y), torch.func.jacfwd(func)(x, y))
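
Usage sketch for the new entry point, mirroring test/test_forecast.py (single process, no mpi4py installed, default reverse-mode differentiation); the toy model and parameter generator below are illustrative only, not part of the diff:

    import torch
    from birds.forecast import compute_forecast_loss_and_jacobian

    loss_fn = torch.nn.MSELoss()
    model = lambda x: [x ** 2]  # toy differentiable simulator returning a list of outputs
    parameter_generator = lambda n: 2.0 * torch.ones((n, 2))  # stands in for sampling from the flow
    observed_outputs = [torch.tensor([4.0, 4.0])]

    # parameters: one tensor per retained (non-NaN) sample,
    # loss: average loss over those samples,
    # jacobians: one loss jacobian per sample
    parameters, loss, jacobians = compute_forecast_loss_and_jacobian(
        loss_fn, model, parameter_generator, 5, observed_outputs
    )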
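Quick check of the chunked forward-mode jacobian against torch.func.jacfwd, along the lines of test/test_jacfwd.py; chunk_size=1 is an arbitrary choice and assumes a torch version (>= 2.0) whose vmap accepts a chunk_size argument:

    import torch
    from birds.jacfwd import jacfwd

    func = lambda x: x ** 2 + 3 * x
    x = torch.tensor([1.0, 2.0, 3.0])
    # push the standard-basis vectors through the JVP one at a time
    chunked = jacfwd(func, chunk_size=1)(x)
    assert torch.allclose(chunked, torch.func.jacfwd(func)(x))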