diff --git a/docs/source/examples/monitoring.rst b/docs/source/examples/monitoring.rst index 186c4c2a4..206136cab 100644 --- a/docs/source/examples/monitoring.rst +++ b/docs/source/examples/monitoring.rst @@ -29,7 +29,7 @@ they have a negative inner product). """Prints the extracted weights.""" print(f"Weights: {weights}") - def print_similarity_with_gd(_, inputs: torch.Tensor, aggregation: torch.Tensor) -> None: + def print_similarity_with_gd(_, inputs: tuple[torch.Tensor], aggregation: torch.Tensor) -> None: """Prints the cosine similarity between the aggregation and the average gradient.""" matrix = inputs[0] gd_output = matrix.mean(dim=0) diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 9dbc440bb..04c586eba 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -222,7 +222,7 @@ def print_weights(_, __, weights: torch.Tensor) -> None: """Prints the extracted weights.""" print(f"Weights: {weights}") - def print_similarity_with_gd(_, inputs: torch.Tensor, aggregation: torch.Tensor) -> None: + def print_similarity_with_gd(_, inputs: tuple[torch.Tensor], aggregation: torch.Tensor) -> None: """Prints the cosine similarity between the aggregation and the average gradient.""" matrix = inputs[0] gd_output = matrix.mean(dim=0)