diff --git a/cuda_core/cuda/core/experimental/_graph.py b/cuda_core/cuda/core/experimental/_graph.py index b8ebe9ae5..a82bd70f5 100644 --- a/cuda_core/cuda/core/experimental/_graph.py +++ b/cuda_core/cuda/core/experimental/_graph.py @@ -318,7 +318,10 @@ def complete(self, options: GraphCompleteOptions | None = None) -> Graph: raise RuntimeError( "Instantiation for device launch failed due to the nodes belonging to different contexts." ) - elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED: + elif ( + _py_major_minor >= (12, 8) + and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED + ): raise RuntimeError("One or more conditional handles are not associated with conditional builders.") elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS: raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}") diff --git a/cuda_core/tests/test_graph.py b/cuda_core/tests/test_graph.py index cc558b6d2..615f7242c 100644 --- a/cuda_core/tests/test_graph.py +++ b/cuda_core/tests/test_graph.py @@ -304,7 +304,7 @@ def test_graph_conditional_if_else(init_cuda, condition_value): try: gb_if, gb_else = gb.if_else(handle) except RuntimeError as e: - with pytest.raises(RuntimeError, match="^Driver version"): + with pytest.raises(RuntimeError, match="^(Driver|Binding) version"): raise e gb.end_building() b.close() @@ -377,7 +377,7 @@ def test_graph_conditional_switch(init_cuda, condition_value): try: gb_case = list(gb.switch(handle, 3)) except RuntimeError as e: - with pytest.raises(RuntimeError, match="^Driver version"): + with pytest.raises(RuntimeError, match="^(Driver|Binding) version"): raise e gb.end_building() b.close() @@ -568,7 +568,7 @@ def build_graph(condition_value): try: gb_case = list(gb.switch(handle, 3)) except Exception as e: - with pytest.raises(RuntimeError, match="^Driver version"): + with pytest.raises(RuntimeError, match="^(Driver|Binding) version"): raise e gb.end_building() raise e @@ -599,7 +599,7 @@ def build_graph(condition_value): try: graph_variants = [build_graph(0), build_graph(1), build_graph(2)] except Exception as e: - with pytest.raises(RuntimeError, match="^Driver version"): + with pytest.raises(RuntimeError, match="^(Driver|Binding) version"): raise e b.close() pytest.skip("Driver does not support conditional switch")