From 9484922a1e6b686fbef6ca83b9af64753a3fe9d3 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 22 Dec 2024 19:57:49 -0500 Subject: [PATCH 1/2] Fix ReactantPythonCallExt.jl --- ext/ReactantPythonCallExt.jl | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/ext/ReactantPythonCallExt.jl b/ext/ReactantPythonCallExt.jl index d42945018f..84f7314a52 100644 --- a/ext/ReactantPythonCallExt.jl +++ b/ext/ReactantPythonCallExt.jl @@ -8,6 +8,25 @@ using PythonCall const jaxptr = Ref{Py}() + +const NUMPY_SIMPLE_TYPES = ( + ("bool_", Bool), + ("int8", Int8), + ("int16", Int16), + ("int32", Int32), + ("int64", Int64), + ("uint8", UInt8), + ("uint16", UInt16), + ("uint32", UInt32), + ("uint64", UInt64), + ("float16", Float16), + ("float32", Float32), + ("float64", Float64), + ("complex32", ComplexF16), + ("complex64", ComplexF32), + ("complex128", ComplexF64), +) + function PythonCall.pycall( f::Py, arg0::Reactant.TracedRArray, argNs::Reactant.TracedRArray...; kwargs... ) @@ -16,7 +35,7 @@ function PythonCall.pycall( inputs = map((arg0, argNs...)) do arg JT = eltype(arg) PT = nothing - for (CPT, CJT) in PythonCall.Convert.NUMPY_SIMPLE_TYPES + for (CPT, CJT) in NUMPY_SIMPLE_TYPES if JT == CJT PT = CPT break From e8914429136206874c2fc1c556cc4ddee5022d30 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 22 Dec 2024 19:59:38 -0500 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/ReactantPythonCallExt.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/ext/ReactantPythonCallExt.jl b/ext/ReactantPythonCallExt.jl index 84f7314a52..be5b61fdd3 100644 --- a/ext/ReactantPythonCallExt.jl +++ b/ext/ReactantPythonCallExt.jl @@ -8,7 +8,6 @@ using PythonCall const jaxptr = Ref{Py}() - const NUMPY_SIMPLE_TYPES = ( ("bool_", Bool), ("int8", Int8),