Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/torchjd/autogram/_jacobian_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def vmap(
jac_outputs: tuple[Tensor, ...],
args: tuple[PyTree, ...],
kwargs: dict[str, PyTree],
) -> tuple[Tensor, None]: # type: ignore[reportIncompatibleMethodOverride]
) -> tuple[Tensor, None]: # ty: ignore[invalid-method-override]
# There is a non-batched dimension
# We do not vmap over the args, kwargs, or rg_outputs for the non-batched dimension
generalized_jacobian = torch.vmap(compute_jacobian_fn, in_dims=in_dims[1:])(
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autogram/_module_hook_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def setup_context(
ctx: Any,
inputs: tuple,
_,
) -> None: # type: ignore[reportIncompatibleMethodOverride]
) -> None: # ty: ignore[invalid-method-override]
ctx.gramian_accumulation_phase = inputs[0]
ctx.gramian_computer = inputs[1]
ctx.args = inputs[2]
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/autojac/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def get_leaf_tensors(tensors: Iterable[Tensor], excluded: Iterable[Tensor]) -> O

# accumulate_grads contains instances of AccumulateGrad, which contain a `variable` field.
# They cannot be typed as such because AccumulateGrad is not public.
leaves = OrderedSet([g.variable for g in accumulate_grads]) # type: ignore[attr-defined]
leaves = OrderedSet([g.variable for g in accumulate_grads]) # ty: ignore[unresolved-attribute]

return leaves

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/aggregation/test_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_representations() -> None:


def test_invalid_scale_mode() -> None:
aggregator = AlignedMTL(scale_mode="test") # type: ignore[arg-type]
aggregator = AlignedMTL(scale_mode="test") # ty: ignore[invalid-argument-type]
matrix = ones_(3, 4)
with raises(ValueError, match=r"Invalid scale_mode=.*Expected"):
aggregator(matrix)
4 changes: 2 additions & 2 deletions tests/unit/autojac/test_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def test_input_retaining_grad_fails() -> None:

with raises(RuntimeError):
# Using such a BatchedTensor should result in an error
_ = -b.grad # type: ignore[unsupported-operator]
_ = -b.grad # ty: ignore[unsupported-operator]


def test_non_input_retaining_grad_fails() -> None:
Expand All @@ -336,7 +336,7 @@ def test_non_input_retaining_grad_fails() -> None:

with raises(RuntimeError):
# Using such a BatchedTensor should result in an error
_ = -b.grad # type: ignore[unsupported-operator]
_ = -b.grad # ty: ignore[unsupported-operator]


@mark.parametrize("chunk_size", [1, 3, None])
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/autojac/test_jac.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_input_retaining_grad_fails() -> None:

with raises(RuntimeError):
# Using such a BatchedTensor should result in an error
_ = -b.grad # type: ignore[unsupported-operator]
_ = -b.grad # ty: ignore[unsupported-operator]


def test_non_input_retaining_grad_fails() -> None:
Expand All @@ -334,7 +334,7 @@ def test_non_input_retaining_grad_fails() -> None:

with raises(RuntimeError):
# Using such a BatchedTensor should result in an error
_ = -b.grad # type: ignore[unsupported-operator]
_ = -b.grad # ty: ignore[unsupported-operator]


@mark.parametrize("chunk_size", [1, 3, None])
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/autojac/test_mtl_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ def test_shared_param_retaining_grad_fails() -> None:

with raises(RuntimeError):
# Using such a BatchedTensor should result in an error
_ = -a.grad # type: ignore[unsupported-operator]
_ = -a.grad # ty: ignore[unsupported-operator]


def test_shared_activation_retaining_grad_fails() -> None:
Expand Down Expand Up @@ -477,7 +477,7 @@ def test_shared_activation_retaining_grad_fails() -> None:

with raises(RuntimeError):
# Using such a BatchedTensor should result in an error
_ = -a.grad # type: ignore[unsupported-operator]
_ = -a.grad # ty: ignore[unsupported-operator]


def test_tasks_params_overlap() -> None:
Expand Down
5 changes: 3 additions & 2 deletions tests/utils/forward_backwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,10 @@ def get_vjp(grad_outputs: Tensor) -> list[Tensor]:

jacobians = vmap(get_vjp)(torch.diag(torch.ones_like(output)))
jacobian_matrices = [jacobian.reshape([jacobian.shape[0], -1]) for jacobian in jacobians]
gramian = sum([jacobian @ jacobian.T for jacobian in jacobian_matrices])
products = [jacobian @ jacobian.T for jacobian in jacobian_matrices]
gramian = torch.stack(products).sum(dim=0)

return gramian
return PSDTensor(gramian)


class CloneParams:
Expand Down
Loading