diff --git a/birds/forecast.py b/birds/forecast.py index 352a2b0..3b0fc56 100644 --- a/birds/forecast.py +++ b/birds/forecast.py @@ -2,7 +2,6 @@ import numpy as np import warnings - from birds.mpi_setup import mpi_size, mpi_rank, mpi_comm from birds.jacfwd import jacfwd @@ -76,7 +75,8 @@ def compute_forecast_loss_and_jacobian( else: params_list_comm = None # scatter the parameters to all ranks - params_list_comm = mpi_comm.bcast(params_list_comm, root=0) + if mpi_comm is not None: + params_list_comm = mpi_comm.bcast(params_list_comm, root=0) # select forward or reverse jacobian calculator if diff_mode == "reverse": jacobian_diff_mode = torch.func.jacrev diff --git a/test/test_forecast.py b/test/test_forecast.py index a94ad96..3814c5b 100644 --- a/test/test_forecast.py +++ b/test/test_forecast.py @@ -3,6 +3,7 @@ from birds.forecast import compute_loss, compute_forecast_loss_and_jacobian + class TestForecast: def test__compute_loss(self): loss_fn = torch.nn.MSELoss() @@ -28,7 +29,7 @@ def test__compute_loss(self): def test__compute_forecast_loss(self): loss_fn = torch.nn.MSELoss() - model = lambda x: [x ** 2] + model = lambda x: [x**2] parameter_generator = lambda x: 2.0 * torch.ones((x, 2)) observed_outputs = [torch.tensor([4.0, 4.0])] parameters, loss, jacobians = compute_forecast_loss_and_jacobian( @@ -41,4 +42,3 @@ def test__compute_forecast_loss(self): assert len(jacobians) == 5 for jacob in jacobians: assert torch.allclose(jacob, torch.tensor([0.0, 0.0])) - diff --git a/test/test_jacfwd.py b/test/test_jacfwd.py index decc8d5..d985545 100644 --- a/test/test_jacfwd.py +++ b/test/test_jacfwd.py @@ -2,12 +2,13 @@ from torch.func import jacrev import torch + class TestJacFwd: def test__vs_pytorch(self): - func = lambda x: x ** 2 + 3*x + func = lambda x: x**2 + 3 * x x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) - assert torch.allclose(jacfwd(func)(x), torch.func.jacfwd(func)(x)) + assert torch.allclose(jacfwd(func)(x), torch.func.jacfwd(func)(x)) func = lambda x, y: x**2 + y x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) y = torch.tensor([2.0, 3.0, 4.0], requires_grad=True) - assert torch.allclose(jacfwd(func)(x, y), torch.func.jacfwd(func)(x, y)) + assert torch.allclose(jacfwd(func)(x, y), torch.func.jacfwd(func)(x, y))