Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/pyrecest/_backend/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,9 +775,7 @@ def assignment(x, values, indices, axis=0):
x_new[indices] = values
return x_new
zip_indices = (
_is_iterable(indices)
and len(indices) > 0
and _is_iterable(indices[0])
_is_iterable(indices) and len(indices) > 0 and _is_iterable(indices[0])
)
len_indices = _assignment_index_length(indices, zip_indices)
if zip_indices:
Expand Down Expand Up @@ -832,9 +830,7 @@ def assignment_by_sum(x, values, indices, axis=0):
x_new[indices] += values
return x_new
zip_indices = (
_is_iterable(indices)
and len(indices) > 0
and _is_iterable(indices[0])
_is_iterable(indices) and len(indices) > 0 and _is_iterable(indices[0])
)
len_indices = _assignment_index_length(indices, zip_indices)
if zip_indices:
Expand Down
2 changes: 1 addition & 1 deletion src/pyrecest/_backend/pytorch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,4 @@ def multivariate_normal(mean, cov, size=None):
size = ()
elif not hasattr(size, "__iter__"):
size = (size,)
return _MultivariateNormal(mean, cov).sample(size)
return _MultivariateNormal(mean, cov).sample(size)
2 changes: 1 addition & 1 deletion src/pyrecest/utils/association_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,4 +453,4 @@ def pairwise_cost_matrix(
return -log(probabilities)
if mode == "one_minus_probability":
return 1.0 - probabilities
raise ValueError(f"Unsupported cost mode: {mode}")
raise ValueError(f"Unsupported cost mode: {mode}")
2 changes: 0 additions & 2 deletions tests/backend_support/test_random_uniform_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import importlib.util

import pytest

from tests.support.backend_runner import run_backend_code


_UNIFORM_ARRAY_BOUNDS_CHECK = """
import pyrecest.backend as backend
from pyrecest.backend import random
Expand Down
2 changes: 1 addition & 1 deletion tests/test_association_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,4 +226,4 @@ def test_calibrated_pairwise_association_model_uses_named_components(self):


if __name__ == "__main__":
unittest.main()
unittest.main()
2 changes: 1 addition & 1 deletion tests/test_backend_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@ def test_jax_multinomial_uses_and_advances_global_state(self):


if __name__ == "__main__":
unittest.main()
unittest.main()
Loading