Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyArray conversion speedups and PyArrayFromBuffer #487

Merged
merged 22 commits into from Nov 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 57 additions & 0 deletions benchmarks/arrayperf.jl
@@ -0,0 +1,57 @@
using PyCall, BenchmarkTools, DataStructures
using PyCall: PyArray_Info

results = OrderedDict{String,Any}()

let
np = pyimport("numpy")
nprand = np["random"]["rand"]
# nparray_pyo(x) = pycall(np["array"], PyObject, x)
# pytestarray(sz::Int...) = pycall(np["reshape"], PyObject, nparray_pyo(1:prod(sz)), sz)

# no convert baseline
nprand_pyo(sz...) = pycall(nprand, PyObject, sz...)

for arr_size in [(2,2), (100,100)]
pyo_arr = nprand_pyo(arr_size...)
results["nprand_pyo$arr_size"] = @benchmark $nprand_pyo($arr_size...)
println("nprand_pyo $arr_size:\n"); display(results["nprand_pyo$arr_size"])
println("--------------------------------------------------")

results["convert_pyarr$arr_size"] = @benchmark $convert(PyArray, $pyo_arr)
println("convert_pyarr $arr_size:\n"); display(results["convert_pyarr$arr_size"])
println("--------------------------------------------------")

results["PyArray_Info$arr_size"] = @benchmark $PyArray_Info($pyo_arr)
println("PyArray_Info $arr_size:\n"); display(results["PyArray_Info$arr_size"])
println("--------------------------------------------------")

results["convert_pyarrbuf$arr_size"] = @benchmark $PyArray($pyo_arr)
println("convert_pyarrbuf $arr_size:\n"); display(results["convert_pyarrbuf$arr_size"])
println("--------------------------------------------------")

results["convert_arr$arr_size"] = @benchmark convert(Array, $pyo_arr)
println("convert_arr $arr_size:\n"); display(results["convert_arr$arr_size"])
println("--------------------------------------------------")

results["convert_arrbuf$arr_size"] = @benchmark $NoCopyArray($pyo_arr)
println("convert_arrbuf $arr_size:\n"); display(results["convert_arrbuf$arr_size"])
println("--------------------------------------------------")

pyarr = convert(PyArray, pyo_arr)
results["setdata!$arr_size"] = @benchmark $setdata!($pyarr, $pyo_arr)
println("setdata!:\n"); display(results["setdata!$arr_size"])
println("--------------------------------------------------")

pyarr = convert(PyArray, pyo_arr)
pybuf=PyBuffer()
results["setdata! bufprealloc$arr_size"] =
@benchmark $setdata!($pyarr, $pyo_arr, $pybuf)
println("setdata! bufprealloc:\n"); display(results["setdata! bufprealloc$arr_size"])
println("--------------------------------------------------")
end
end
println()
println("Mean times")
println("----------")
foreach((r)->println(rpad(r[1],27), ": ", mean(r[2])), results)
4 changes: 3 additions & 1 deletion src/PyCall.jl
Expand Up @@ -5,7 +5,8 @@ module PyCall
using Compat, VersionParsing

export pycall, pycall!, pyimport, pyimport_e, pybuiltin, PyObject, PyReverseDims,
PyPtr, pyincref, pydecref, pyversion, PyArray, PyArray_Info,
PyPtr, pyincref, pydecref, pyversion,
PyArray, PyArray_Info, PyBuffer,
pyerr_check, pyerr_clear, pytype_query, PyAny, @pyimport, PyDict,
pyisinstance, pywrap, pytypeof, pyeval, PyVector, pystring, pystr, pyrepr,
pyraise, pytype_mapping, pygui, pygui_start, pygui_stop,
Expand Down Expand Up @@ -177,6 +178,7 @@ pytypeof(o::PyObject) = ispynull(o) ? throw(ArgumentError("NULL PyObjects have n

const TypeTuple = Union{Type,NTuple{N, Type}} where {N}
include("pybuffer.jl")
include("pyarray.jl")
include("conversions.jl")
include("pytype.jl")
include("pyiterator.jl")
Expand Down
6 changes: 2 additions & 4 deletions src/conversions.jl
Expand Up @@ -768,13 +768,11 @@ function pysequence_query(o::PyObject)
return AbstractRange
elseif ispybytearray(o)
return Vector{UInt8}
elseif !haskey(o, "__array_interface__")
elseif !isbuftype(o)
# only handle PyList for now
return pyisinstance(o, @pyglobalobj :PyList_Type) ? Array : Union{}
else
otypestr = get(o["__array_interface__"], PyObject, "typestr")
typestr = convert(AbstractString, otypestr) # Could this just be String now?
T = npy_typestrs[typestr[2:end]]
T, native_byteorder = array_format(o)
if T == PyPtr
T = PyObject
end
Expand Down