diff --git a/ext/ReactantPythonCallExt.jl b/ext/ReactantPythonCallExt.jl index d42945018f..be5b61fdd3 100644 --- a/ext/ReactantPythonCallExt.jl +++ b/ext/ReactantPythonCallExt.jl @@ -8,6 +8,24 @@ using PythonCall const jaxptr = Ref{Py}() +const NUMPY_SIMPLE_TYPES = ( + ("bool_", Bool), + ("int8", Int8), + ("int16", Int16), + ("int32", Int32), + ("int64", Int64), + ("uint8", UInt8), + ("uint16", UInt16), + ("uint32", UInt32), + ("uint64", UInt64), + ("float16", Float16), + ("float32", Float32), + ("float64", Float64), + ("complex32", ComplexF16), + ("complex64", ComplexF32), + ("complex128", ComplexF64), +) + function PythonCall.pycall( f::Py, arg0::Reactant.TracedRArray, argNs::Reactant.TracedRArray...; kwargs... ) @@ -16,7 +34,7 @@ function PythonCall.pycall( inputs = map((arg0, argNs...)) do arg JT = eltype(arg) PT = nothing - for (CPT, CJT) in PythonCall.Convert.NUMPY_SIMPLE_TYPES + for (CPT, CJT) in NUMPY_SIMPLE_TYPES if JT == CJT PT = CPT break