From 4f42acaba7b6cf3fdb5467235e0460fe249e2221 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 12 May 2025 16:48:22 +0200 Subject: [PATCH] Fix type hint of inputs parameter in print_similarity_with_gd --- docs/source/examples/monitoring.rst | 2 +- tests/doc/test_rst.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/examples/monitoring.rst b/docs/source/examples/monitoring.rst index 186c4c2a4..206136cab 100644 --- a/docs/source/examples/monitoring.rst +++ b/docs/source/examples/monitoring.rst @@ -29,7 +29,7 @@ they have a negative inner product). """Prints the extracted weights.""" print(f"Weights: {weights}") - def print_similarity_with_gd(_, inputs: torch.Tensor, aggregation: torch.Tensor) -> None: + def print_similarity_with_gd(_, inputs: tuple[torch.Tensor], aggregation: torch.Tensor) -> None: """Prints the cosine similarity between the aggregation and the average gradient.""" matrix = inputs[0] gd_output = matrix.mean(dim=0) diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 9dbc440bb..04c586eba 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -222,7 +222,7 @@ def print_weights(_, __, weights: torch.Tensor) -> None: """Prints the extracted weights.""" print(f"Weights: {weights}") - def print_similarity_with_gd(_, inputs: torch.Tensor, aggregation: torch.Tensor) -> None: + def print_similarity_with_gd(_, inputs: tuple[torch.Tensor], aggregation: torch.Tensor) -> None: """Prints the cosine similarity between the aggregation and the average gradient.""" matrix = inputs[0] gd_output = matrix.mean(dim=0)