Skip to content

Commit

Permalink
remove mpi4py dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
arnauqb committed May 12, 2023
1 parent 34ea5f6 commit b8bc997
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 7 deletions.
4 changes: 2 additions & 2 deletions birds/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import numpy as np
import warnings


from birds.mpi_setup import mpi_size, mpi_rank, mpi_comm
from birds.jacfwd import jacfwd

Expand Down Expand Up @@ -76,7 +75,8 @@ def compute_forecast_loss_and_jacobian(
else:
params_list_comm = None
# scatter the parameters to all ranks
params_list_comm = mpi_comm.bcast(params_list_comm, root=0)
if mpi_comm is not None:
params_list_comm = mpi_comm.bcast(params_list_comm, root=0)
# select forward or reverse jacobian calculator
if diff_mode == "reverse":
jacobian_diff_mode = torch.func.jacrev
Expand Down
4 changes: 2 additions & 2 deletions test/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from birds.forecast import compute_loss, compute_forecast_loss_and_jacobian


class TestForecast:
def test__compute_loss(self):
loss_fn = torch.nn.MSELoss()
Expand All @@ -28,7 +29,7 @@ def test__compute_loss(self):

def test__compute_forecast_loss(self):
loss_fn = torch.nn.MSELoss()
model = lambda x: [x ** 2]
model = lambda x: [x**2]
parameter_generator = lambda x: 2.0 * torch.ones((x, 2))
observed_outputs = [torch.tensor([4.0, 4.0])]
parameters, loss, jacobians = compute_forecast_loss_and_jacobian(
Expand All @@ -41,4 +42,3 @@ def test__compute_forecast_loss(self):
assert len(jacobians) == 5
for jacob in jacobians:
assert torch.allclose(jacob, torch.tensor([0.0, 0.0]))

7 changes: 4 additions & 3 deletions test/test_jacfwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from torch.func import jacrev
import torch


class TestJacFwd:
def test__vs_pytorch(self):
func = lambda x: x ** 2 + 3*x
func = lambda x: x**2 + 3 * x
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
assert torch.allclose(jacfwd(func)(x), torch.func.jacfwd(func)(x))
assert torch.allclose(jacfwd(func)(x), torch.func.jacfwd(func)(x))
func = lambda x, y: x**2 + y
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = torch.tensor([2.0, 3.0, 4.0], requires_grad=True)
assert torch.allclose(jacfwd(func)(x, y), torch.func.jacfwd(func)(x, y))
assert torch.allclose(jacfwd(func)(x, y), torch.func.jacfwd(func)(x, y))

0 comments on commit b8bc997

Please sign in to comment.