Skip to content

Commit

Permalink
need to implement forecast loss calculation in parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
arnauqb committed May 11, 2023
1 parent ee59e4c commit 1d44bb5
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
6 changes: 4 additions & 2 deletions birds/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def compute_forecast_loss(
loss += loss_i
n_samples_not_nan += 1
if n_samples_not_nan == 0:
return torch.nan
return loss / n_samples_not_nan
loss = torch.nan
else:
loss = loss / n_samples_not_nan
return loss, loss # need to return it twice for the jacobian calculation

Binary file removed logo.png
Binary file not shown.
6 changes: 3 additions & 3 deletions test/test_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ def test__compute_forecast_loss(self):
observed_outputs = [torch.tensor(4.0)]
assert compute_forecast_loss(
loss_fn, model, parameter_generator, 5, observed_outputs
) == 0
) == (0,0)
parameter_generator = lambda: torch.tensor(float("nan"))
assert np.isnan(
compute_forecast_loss(
loss_fn, model, parameter_generator, 5, observed_outputs
)
)[0]
)
parameter_generator = lambda: torch.tensor(2.0)
model = lambda x: [x ** 3]
assert compute_forecast_loss(
loss_fn, model, parameter_generator, 5, observed_outputs
) == (8-4)**2
) == (16,16)

0 comments on commit 1d44bb5

Please sign in to comment.