diff --git a/tests/strategies/test_deepspeed_strategy.py b/tests/strategies/test_deepspeed_strategy.py index c99cd2fe357bd..a9b30b0015bf3 100644 --- a/tests/strategies/test_deepspeed_strategy.py +++ b/tests/strategies/test_deepspeed_strategy.py @@ -168,12 +168,13 @@ def test_deepspeed_strategy_env(tmpdir, monkeypatch, deepspeed_config): @RunIf(deepspeed=True) +@mock.patch("torch.cuda.device_count", return_value=1) @pytest.mark.parametrize("precision", [16, "mixed"]) @pytest.mark.parametrize( "amp_backend", ["native", pytest.param("apex", marks=RunIf(amp_apex=True))], ) -def test_deepspeed_precision_choice(amp_backend, precision, tmpdir): +def test_deepspeed_precision_choice(_, amp_backend, precision, tmpdir): """Test to ensure precision plugin is also correctly chosen. DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin