-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
192 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
""" | ||
Modified version of torch.func.jacfwd that accepts chunk size. | ||
""" | ||
from torch._functorch.eager_transforms import ( | ||
_slice_argnums, | ||
_construct_standard_basis_for, | ||
vmap, | ||
tree_flatten, | ||
Callable, | ||
argnums_t, | ||
wraps, | ||
tree_unflatten, | ||
_jvp_with_argnums, | ||
safe_unflatten, | ||
) | ||
|
||
|
||
def jacfwd( | ||
func: Callable, | ||
argnums: argnums_t = 0, | ||
has_aux: bool = False, | ||
*, | ||
randomness: str = "error", | ||
chunk_size=None | ||
): | ||
@wraps(func) | ||
def wrapper_fn(*args): | ||
primals = args if argnums is None else _slice_argnums(args, argnums) | ||
flat_primals, primals_spec = tree_flatten(primals) | ||
flat_primals_numels = tuple(p.numel() for p in flat_primals) | ||
flat_basis = _construct_standard_basis_for(flat_primals, flat_primals_numels) | ||
basis = tree_unflatten(flat_basis, primals_spec) | ||
|
||
def push_jvp(basis): | ||
output = _jvp_with_argnums( | ||
func, args, basis, argnums=argnums, has_aux=has_aux | ||
) | ||
# output[0] is the output of `func(*args)` | ||
if has_aux: | ||
_, jvp_out, aux = output | ||
return jvp_out, aux | ||
_, jvp_out = output | ||
return jvp_out | ||
|
||
results = vmap(push_jvp, randomness=randomness, chunk_size=chunk_size)(basis) | ||
if has_aux: | ||
results, aux = results | ||
# aux is in the standard basis format, e.g. NxN matrix | ||
# We need to fetch the first element as original `func` output | ||
flat_aux, aux_spec = tree_flatten(aux) | ||
flat_aux = [value[0] for value in flat_aux] | ||
aux = tree_unflatten(flat_aux, aux_spec) | ||
|
||
jac_outs, spec = tree_flatten(results) | ||
# Most probably below output check can never raise an error | ||
# as jvp should test the output before | ||
# assert_non_empty_output(jac_outs, 'jacfwd(f, ...)(*args)') | ||
|
||
jac_outs_ins = tuple( | ||
tuple( | ||
safe_unflatten(jac_out_in, -1, primal.shape) | ||
for primal, jac_out_in in zip( | ||
flat_primals, | ||
jac_out.movedim(0, -1).split(flat_primals_numels, dim=-1), | ||
) | ||
) | ||
for jac_out in jac_outs | ||
) | ||
jac_outs_ins = tuple( | ||
tree_unflatten(jac_ins, primals_spec) for jac_ins in jac_outs_ins | ||
) | ||
|
||
if isinstance(argnums, int): | ||
jac_outs_ins = tuple(jac_ins[0] for jac_ins in jac_outs_ins) | ||
if has_aux: | ||
return tree_unflatten(jac_outs_ins, spec), aux | ||
return tree_unflatten(jac_outs_ins, spec) | ||
|
||
return wrapper_fn |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
try: | ||
from mpi4py import MPI | ||
mpi_comm = MPI.COMM_WORLD | ||
mpi_rank = mpi_comm.Get_rank() | ||
mpi_size = mpi_comm.Get_size() | ||
except: | ||
mpi_comm = None | ||
mpi_rank = 0 | ||
mpi_size = 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,44 @@ | ||
import pytest | ||
import numpy as np | ||
import torch | ||
|
||
from birds.forecast import compute_loss, compute_forecast_loss | ||
from birds.forecast import compute_loss, compute_forecast_loss_and_jacobian | ||
|
||
class TestForecast: | ||
def test__compute_loss(self): | ||
loss_fn = torch.nn.MSELoss() | ||
observed_outputs = [torch.tensor([1.0, 2, 3]), torch.tensor([4.0, 5, 6])] | ||
simulated_outputs = [torch.tensor([1.0, 2, 3]), torch.tensor([4.0, 5, 6])] | ||
assert compute_loss(loss_fn, observed_outputs, simulated_outputs) == 0 | ||
assert compute_loss(loss_fn, observed_outputs, simulated_outputs) == (0, 0) | ||
simulated_outputs = [torch.tensor([1.0, 2, 3]), torch.tensor([4.0, 5, 7])] | ||
assert torch.isclose( | ||
compute_loss(loss_fn, observed_outputs, simulated_outputs), | ||
compute_loss(loss_fn, observed_outputs, simulated_outputs)[0], | ||
torch.tensor(0.3333), | ||
rtol=1e-3, | ||
) | ||
simulated_outputs = [ | ||
torch.tensor([1.0, 2, 3]), | ||
torch.tensor([4.0, 5, float("nan")]), | ||
] | ||
assert compute_loss(loss_fn, observed_outputs, simulated_outputs) == 0 | ||
assert compute_loss(loss_fn, observed_outputs, simulated_outputs) == (0, 0) | ||
simulated_outputs = [ | ||
torch.tensor([1.0, 2, float("nan")]), | ||
torch.tensor([4.0, 5, float("nan")]), | ||
] | ||
assert np.isnan(compute_loss(loss_fn, observed_outputs, simulated_outputs)) | ||
assert np.isnan(compute_loss(loss_fn, observed_outputs, simulated_outputs)[0]) | ||
|
||
def test__compute_forecast_loss(self): | ||
loss_fn = torch.nn.MSELoss() | ||
model = lambda x: [x ** 2] | ||
parameter_generator = lambda: torch.tensor(2.0) | ||
observed_outputs = [torch.tensor(4.0)] | ||
assert compute_forecast_loss( | ||
parameter_generator = lambda x: 2.0 * torch.ones((x, 2)) | ||
observed_outputs = [torch.tensor([4.0, 4.0])] | ||
parameters, loss, jacobians = compute_forecast_loss_and_jacobian( | ||
loss_fn, model, parameter_generator, 5, observed_outputs | ||
) == (0,0) | ||
parameter_generator = lambda: torch.tensor(float("nan")) | ||
assert np.isnan( | ||
compute_forecast_loss( | ||
loss_fn, model, parameter_generator, 5, observed_outputs | ||
)[0] | ||
) | ||
parameter_generator = lambda: torch.tensor(2.0) | ||
model = lambda x: [x ** 3] | ||
assert compute_forecast_loss( | ||
loss_fn, model, parameter_generator, 5, observed_outputs | ||
) == (16,16) | ||
assert len(parameters) == 5 | ||
for param in parameters: | ||
assert torch.allclose(param, torch.tensor([2.0, 2.0])) | ||
assert loss == torch.tensor(0.0) | ||
assert len(jacobians) == 5 | ||
for jacob in jacobians: | ||
assert torch.allclose(jacob, torch.tensor([0.0, 0.0])) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from birds.jacfwd import jacfwd | ||
from torch.func import jacrev | ||
import torch | ||
|
||
class TestJacFwd: | ||
def test__vs_pytorch(self): | ||
func = lambda x: x ** 2 + 3*x | ||
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) | ||
assert torch.allclose(jacfwd(func)(x), torch.func.jacfwd(func)(x)) | ||
func = lambda x, y: x**2 + y | ||
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) | ||
y = torch.tensor([2.0, 3.0, 4.0], requires_grad=True) | ||
assert torch.allclose(jacfwd(func)(x, y), torch.func.jacfwd(func)(x, y)) |