Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/pyrecest/_backend/pytorch/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,22 @@ def _torch_as_like(value, like):
}


def _default_linalg_dtype():
dtype = get_default_dtype()
if dtype in (_torch.float32, _torch.float64):
return dtype
if dtype == _np.dtype("float32"):
return _torch.float32
if dtype == _np.dtype("float64"):
return _torch.float64
return _torch.float64


def _as_linalg_tensor(value):
"""Convert array-like values to a floating/complex tensor for torch.linalg."""
tensor = array(value)
if not is_floating(tensor) and not is_complex(tensor):
tensor = cast(tensor, dtype=get_default_dtype())
tensor = cast(tensor, dtype=_default_linalg_dtype())
return tensor


Expand All @@ -73,7 +84,7 @@ def _common_linalg_dtype(*tensors):
dtype = _torch.promote_types(dtype, tensor.dtype)
if dtype.is_floating_point or dtype.is_complex:
return dtype
return get_default_dtype()
return _default_linalg_dtype()


class _Logm(_torch.autograd.Function):
Expand Down
60 changes: 35 additions & 25 deletions src/pyrecest/_backend/pytorch/random.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Torch based random backend."""

from math import prod as _prod
from numbers import Integral as _Integral

import torch as _torch
Expand All @@ -17,21 +18,42 @@
}


def _size_type_error():
return TypeError("size must be None, an integer, or a sequence of integers")


def _looks_like_integer_dimension(value):
return isinstance(value, _Integral) and not isinstance(value, bool)


def _integer_dimension(value):
if not _looks_like_integer_dimension(value):
raise _size_type_error()
value = int(value)
if value < 0:
raise ValueError("size dimensions must be non-negative")
return value


def _shape_from_size(size):
if size is None:
return ()
if _looks_like_integer_dimension(size):
return (_integer_dimension(size),)
if isinstance(size, (str, bytes)) or not hasattr(size, "__iter__"):
raise _size_type_error()
return tuple(_integer_dimension(dim) for dim in size)


def _choice_size(size):
if size is None:
return None, 1
if not hasattr(size, "__iter__"):
size = (size,)
size = tuple(int(dim) for dim in size)
return size, int(_torch.prod(_torch.tensor(size)).item())
size = _shape_from_size(size)
return size, _prod(size) if size else 1


def _randint_size(size):
if size is None:
return ()
if not hasattr(size, "__iter__") or isinstance(size, (str, bytes)):
return (size,)
return tuple(size)
return _shape_from_size(size)


def randint(low, high=None, size=None, *args, **kwargs):
Expand All @@ -45,9 +67,7 @@ def randint(low, high=None, size=None, *args, **kwargs):
def _normal_size(size):
if size is None:
return None
if not hasattr(size, "__iter__"):
return (size,)
return tuple(int(dim) for dim in size)
return _shape_from_size(size)


def _normal_device(*values):
Expand Down Expand Up @@ -166,11 +186,7 @@ def seed(*args, **kwargs):


def rand(size=None, dtype=None):
if size is None:
size = ()
elif not hasattr(size, "__iter__"):
size = (size,)
return _torch.rand(size, dtype=dtype)
return _torch.rand(_shape_from_size(size), dtype=dtype)


def multinomial(n, pvals):
Expand Down Expand Up @@ -198,9 +214,7 @@ def normal(loc=0.0, scale=1.0, size=None):

def _uniform_size(size, low, high):
if size is not None:
if not hasattr(size, "__iter__") or isinstance(size, (str, bytes)):
return (size,)
return tuple(int(dim) for dim in size)
return _shape_from_size(size)

try:
return tuple(_torch.broadcast_shapes(low.shape, high.shape))
Expand Down Expand Up @@ -242,11 +256,7 @@ def _floating_distribution_dtype(*values):


def _normal_sample_size(size):
if size is None:
return ()
if not hasattr(size, "__iter__"):
return (size,)
return tuple(size)
return _shape_from_size(size)


@_modify_func_default_dtype(copy=False, kw_only=True)
Expand Down
55 changes: 55 additions & 0 deletions tests/backend/test_pytorch_random_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest

pytest.importorskip("torch")

from pyrecest._backend.pytorch import random # noqa: E402


@pytest.mark.parametrize(
"bad_size",
[True, False, (True,), [False, 2], 1.5, (2.0,), "3"],
)
def test_size_arguments_reject_bool_and_non_integral_dimensions(bad_size):
samplers = (
lambda size: random.rand(size=size),
lambda size: random.uniform(size=size),
lambda size: random.normal(size=size),
lambda size: random.randint(0, 3, size=size),
lambda size: random.choice(3, size=size),
lambda size: random.multivariate_normal([0.0], [[1.0]], size=size),
)

for sampler in samplers:
with pytest.raises(TypeError):
sampler(bad_size)


@pytest.mark.parametrize("bad_size", [-1, (2, -1)])
def test_size_arguments_reject_negative_dimensions(bad_size):
samplers = (
lambda size: random.rand(size=size),
lambda size: random.uniform(size=size),
lambda size: random.normal(size=size),
lambda size: random.randint(0, 3, size=size),
lambda size: random.choice(3, size=size),
lambda size: random.multivariate_normal([0.0], [[1.0]], size=size),
)

for sampler in samplers:
with pytest.raises(ValueError):
sampler(bad_size)


def test_scalar_and_empty_tuple_sizes_keep_scalar_shape():
assert random.rand().shape == ()
assert random.rand(size=()).shape == ()
assert random.normal(size=()).shape == ()
assert random.uniform(size=()).shape == ()
assert random.randint(0, 3, size=()).shape == ()
assert random.multivariate_normal([0.0], [[1.0]], size=()).shape == (1,)


def test_zero_sized_choice_still_works_for_empty_population():
sample = random.choice(0, size=(0,))

assert sample.shape == (0,)
Loading