Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions tests/doc/test_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""

from torch.testing import assert_close
from unit.conftest import DEVICE


def test_backward():
Expand All @@ -13,11 +12,11 @@ def test_backward():
from torchjd import backward
from torchjd.aggregation import UPGrad

param = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
param = torch.tensor([1.0, 2.0], requires_grad=True)
# Compute arbitrary quantities that are function of param
y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ param
y1 = torch.tensor([-1.0, 1.0]) @ param
y2 = (param**2).sum()

backward([y1, y2], UPGrad())

assert_close(param.grad, torch.tensor([0.5000, 2.5000], device=DEVICE), rtol=0.0, atol=1e-04)
assert_close(param.grad, torch.tensor([0.5000, 2.5000]), rtol=0.0, atol=1e-04)
19 changes: 8 additions & 11 deletions tests/unit/aggregation/_inputs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from torch import Tensor
from unit.conftest import DEVICE


def _check_valid_dimensions(n_rows: int, n_cols: int) -> None:
Expand Down Expand Up @@ -37,9 +36,9 @@ def _augment_orthogonal_matrix(orthogonal_matrix: Tensor) -> Tensor:

n_rows = orthogonal_matrix.shape[0]
projection = orthogonal_matrix @ orthogonal_matrix.T
zero = torch.zeros([n_rows], device=DEVICE)
zero = torch.zeros([n_rows])
while True:
random_vector = torch.randn([n_rows], device=DEVICE)
random_vector = torch.randn([n_rows])
projected_vector = random_vector - projection @ random_vector
if not torch.allclose(projected_vector, zero):
break
Expand Down Expand Up @@ -70,7 +69,7 @@ def _generate_unitary_matrix(n_rows: int, n_cols: int) -> Tensor:
"""Generates a unitary matrix of shape [n_rows, n_cols]."""

_check_valid_dimensions(n_rows, n_cols)
partial_matrix = torch.randn([n_rows, 1], device=DEVICE)
partial_matrix = torch.randn([n_rows, 1])
partial_matrix = torch.nn.functional.normalize(partial_matrix, dim=0)

unitary_matrix = _complete_orthogonal_matrix(partial_matrix, n_cols)
Expand All @@ -83,7 +82,7 @@ def _generate_unitary_matrix_with_positive_column(n_rows: int, n_cols: int) -> T
positive vector.
"""
_check_valid_dimensions(n_rows, n_cols)
partial_matrix = torch.abs(torch.randn([n_rows, 1], device=DEVICE))
partial_matrix = torch.abs(torch.randn([n_rows, 1]))
partial_matrix = torch.nn.functional.normalize(partial_matrix, dim=0)

unitary_matrix_with_positive_column = _complete_orthogonal_matrix(partial_matrix, n_cols)
Expand All @@ -94,7 +93,7 @@ def _generate_diagonal_singular_values(rank: int) -> Tensor:
"""
generates a diagonal matrix of positive values sorted in descending order.
"""
singular_values = torch.abs(torch.randn([rank], device=DEVICE))
singular_values = torch.abs(torch.randn([rank]))
singular_values = torch.sort(singular_values, descending=True)[0]
S = torch.diag(singular_values)
return S
Expand All @@ -108,7 +107,7 @@ def generate_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
_check_valid_rank(n_rows, n_cols, rank)

if rank == 0:
matrix = torch.zeros([n_rows, n_cols], device=DEVICE)
matrix = torch.zeros([n_rows, n_cols])
else:
U = _generate_unitary_matrix(n_rows, rank)
V = _generate_unitary_matrix(n_cols, rank)
Expand All @@ -126,7 +125,7 @@ def generate_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:

_check_valid_rank(n_rows, n_cols, rank)
if rank == 0:
matrix = torch.zeros([n_rows, n_cols], device=DEVICE)
matrix = torch.zeros([n_rows, n_cols])
else:
U = _generate_unitary_matrix_with_positive_column(n_rows, rank)
V = _generate_unitary_matrix(n_cols, rank)
Expand Down Expand Up @@ -161,9 +160,7 @@ def generate_stationary_matrix(n_rows: int, n_cols: int, rank: int) -> Tensor:
generate_matrix(n_rows, n_cols, rank) for n_rows, n_cols, rank in _matrix_dimension_triples
]
scaled_matrices = [scale * matrix for scale in _scales for matrix in matrices]
zero_rank_matrices = [
torch.zeros([n_rows, n_cols], device=DEVICE) for n_rows, n_cols in _zero_rank_matrix_shapes
]
zero_rank_matrices = [torch.zeros([n_rows, n_cols]) for n_rows, n_cols in _zero_rank_matrix_shapes]
matrices_2_plus_rows = [matrix for matrix in matrices + zero_rank_matrices if matrix.shape[0] >= 2]
scaled_matrices_2_plus_rows = [
matrix for matrix in scaled_matrices + zero_rank_matrices if matrix.shape[0] >= 2
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/aggregation/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch
from pytest import mark, raises
from unit._utils import ExceptionContext
from unit.conftest import DEVICE

from torchjd.aggregation import Aggregator

Expand All @@ -21,4 +20,4 @@
)
def test_check_is_matrix(shape: Sequence[int], expectation: ExceptionContext):
with expectation:
Aggregator._check_is_matrix(torch.randn(shape, device=DEVICE))
Aggregator._check_is_matrix(torch.randn(shape))
5 changes: 2 additions & 3 deletions tests/unit/aggregation/test_constant.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from pytest import mark
from torch import Tensor
from unit.conftest import DEVICE

from torchjd.aggregation import Constant

Expand All @@ -16,7 +15,7 @@

def _make_aggregator(matrix: Tensor) -> Constant:
n_rows = matrix.shape[0]
weights = torch.tensor([1.0 / n_rows] * n_rows, device=DEVICE)
weights = torch.tensor([1.0 / n_rows] * n_rows)
return Constant(weights)


Expand All @@ -38,6 +37,6 @@ def test_expected_structure_property(cls, aggregator: Constant, matrix: Tensor):


def test_representations():
A = Constant(weights=torch.tensor([1.0, 2.0]))
A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu"))
assert repr(A) == "Constant(weights=tensor([1., 2.]))"
assert str(A) == "Constant([1., 2.])"
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_graddrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class TestGradDrop(ExpectedStructureProperty):


def test_representations():
A = GradDrop(leak=torch.tensor([0.0, 1.0]))
A = GradDrop(leak=torch.tensor([0.0, 1.0], device="cpu"))
assert repr(A) == "GradDrop(leak=tensor([0., 1.]))"
assert str(A) == "GradDrop([0., 1.])"

Expand Down
5 changes: 2 additions & 3 deletions tests/unit/aggregation/test_imtl_g.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from pytest import mark
from torch.testing import assert_close
from unit.conftest import DEVICE

from torchjd.aggregation import IMTLG

Expand All @@ -20,8 +19,8 @@ def test_imtlg_zero():
"""

A = IMTLG()
J = torch.zeros(2, 3, device=DEVICE)
assert_close(A(J), torch.zeros(3, device=DEVICE))
J = torch.zeros(2, 3)
assert_close(A(J), torch.zeros(3))


def test_representations():
Expand Down
5 changes: 2 additions & 3 deletions tests/unit/aggregation/test_mgda.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from pytest import mark
from torch.testing import assert_close
from unit.conftest import DEVICE

from torchjd.aggregation import MGDA
from torchjd.aggregation.mgda import _MGDAWeighting
Expand Down Expand Up @@ -29,7 +28,7 @@ class TestMGDA(ExpectedStructureProperty, NonConflictingProperty, PermutationInv
],
)
def test_mgda_satisfies_kkt_conditions(shape: tuple[int, int]):
matrix = torch.randn(shape, device=DEVICE)
matrix = torch.randn(shape)
weighting = _MGDAWeighting(epsilon=1e-05, max_iters=1000)

gramian = matrix @ matrix.T
Expand All @@ -45,7 +44,7 @@ def test_mgda_satisfies_kkt_conditions(shape: tuple[int, int]):
assert_close(positive_weights.norm(), weights.norm())

weights_sum = weights.sum()
assert_close(weights_sum, torch.ones([], device=DEVICE))
assert_close(weights_sum, torch.ones([]))

# Dual feasibility
positive_mu = mu[mu >= 0]
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/aggregation/test_pcgrad.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from pytest import mark
from torch.testing import assert_close
from unit.conftest import DEVICE

from torchjd.aggregation import PCGrad
from torchjd.aggregation.pcgrad import _PCGradWeighting
Expand Down Expand Up @@ -37,7 +36,7 @@ def test_equivalence_upgrad_sum_two_rows(shape: tuple[int, int]):
rows.
"""

matrix = torch.randn(shape, device=DEVICE)
matrix = torch.randn(shape)

pc_grad_weighting = _PCGradWeighting()
upgrad_sum_weighting = _UPGradWrapper(
Expand Down
7 changes: 3 additions & 4 deletions tests/unit/aggregation/test_upgrad.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from pytest import mark
from torch.testing import assert_close
from unit.conftest import DEVICE

from torchjd.aggregation import UPGrad
from torchjd.aggregation.mean import _MeanWeighting
Expand All @@ -21,8 +20,8 @@ class TestUPGrad(ExpectedStructureProperty, NonConflictingProperty, PermutationI

@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
def test_upgrad_lagrangian_satisfies_kkt_conditions(shape: tuple[int, int]):
matrix = torch.randn(shape, device=DEVICE)
weights = torch.rand(shape[0], device=DEVICE)
matrix = torch.randn(shape)
weights = torch.rand(shape[0])

gramian = matrix @ matrix.T

Expand All @@ -41,7 +40,7 @@ def test_upgrad_lagrangian_satisfies_kkt_conditions(shape: tuple[int, int]):
assert_close(positive_constraint.norm(), constraint.norm(), atol=1e-04, rtol=0)

slackness = torch.trace(lagrange_multiplier @ constraint)
assert_close(slackness, torch.zeros_like(slackness, device=DEVICE), atol=3e-03, rtol=0)
assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)


def test_representations():
Expand Down
33 changes: 16 additions & 17 deletions tests/unit/autojac/_transform/test_accumulate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from pytest import mark, raises
from unit.conftest import DEVICE

from torchjd.autojac._transform import Accumulate, Gradients

Expand All @@ -13,12 +12,12 @@ def test_single_accumulation():
once.
"""

key1 = torch.zeros([], requires_grad=True, device=DEVICE)
key2 = torch.zeros([1], requires_grad=True, device=DEVICE)
key3 = torch.zeros([2, 3], requires_grad=True, device=DEVICE)
value1 = torch.ones([], device=DEVICE)
value2 = torch.ones([1], device=DEVICE)
value3 = torch.ones([2, 3], device=DEVICE)
key1 = torch.zeros([], requires_grad=True)
key2 = torch.zeros([1], requires_grad=True)
key3 = torch.zeros([2, 3], requires_grad=True)
value1 = torch.ones([])
value2 = torch.ones([1])
value3 = torch.ones([2, 3])
input = Gradients({key1: value1, key2: value2, key3: value3})

accumulate = Accumulate([key1, key2, key3])
Expand All @@ -41,12 +40,12 @@ def test_multiple_accumulation(iterations: int):
`iterations` times.
"""

key1 = torch.zeros([], requires_grad=True, device=DEVICE)
key2 = torch.zeros([1], requires_grad=True, device=DEVICE)
key3 = torch.zeros([2, 3], requires_grad=True, device=DEVICE)
value1 = torch.ones([], device=DEVICE)
value2 = torch.ones([1], device=DEVICE)
value3 = torch.ones([2, 3], device=DEVICE)
key1 = torch.zeros([], requires_grad=True)
key2 = torch.zeros([1], requires_grad=True)
key3 = torch.zeros([2, 3], requires_grad=True)
value1 = torch.ones([])
value2 = torch.ones([1])
value3 = torch.ones([2, 3])
input = Gradients({key1: value1, key2: value2, key3: value3})

accumulate = Accumulate([key1, key2, key3])
Expand All @@ -70,8 +69,8 @@ def test_no_requires_grad_fails():
tensor that does not require grad.
"""

key = torch.zeros([1], requires_grad=False, device=DEVICE)
value = torch.ones([1], device=DEVICE)
key = torch.zeros([1], requires_grad=False)
value = torch.ones([1])
input = Gradients({key: value})

accumulate = Accumulate([key])
Expand All @@ -86,8 +85,8 @@ def test_no_leaf_and_no_retains_grad_fails():
tensor that is not a leaf and that does not retain grad.
"""

key = torch.tensor([1.0], requires_grad=True, device=DEVICE) * 2
value = torch.ones([1], device=DEVICE)
key = torch.tensor([1.0], requires_grad=True) * 2
value = torch.ones([1])
input = Gradients({key: value})

accumulate = Accumulate([key])
Expand Down
Loading