diff --git a/benchmarks/seqperf.jl b/benchmarks/seqperf.jl new file mode 100644 index 00000000..7eb8c7fd --- /dev/null +++ b/benchmarks/seqperf.jl @@ -0,0 +1,41 @@ +using PyCall, BenchmarkTools, DataStructures + +results = OrderedDict{String,Any}() +let + np = pyimport("numpy") + nprandint = np["random"]["randint"] + nprand = np["random"]["rand"] + res = PyNULL() + + tuplen = 16 + tpl = convert(PyObject, (1:tuplen...)) + lst = convert(PyObject, Any[1:tuplen...]) + for (name, pycoll) in zip(("tpl", "lst"), (tpl, lst)) + idx = rand(0:(tuplen-1)) + results["standard get $name"] = @benchmark get($pycoll, PyObject, PyObject($idx)) + println("standard get:\n"); display(results["standard get $name"]) + println("--------------------------------------------------") + + idx = rand(0:(tuplen-1)) + results["faster get $name"] = @benchmark get($pycoll, PyObject, $idx) + println("faster get:\n"); display(results["faster get $name"]) + println("--------------------------------------------------") + + idx = rand(0:(tuplen-1)) + results["get! $name"] = @benchmark get!($res, $pycoll, PyObject, $idx) + println("get!:\n"); display(results["get! $name"]) + println("--------------------------------------------------") + + if pycoll == tpl + idx = rand(0:(tuplen-1)) + results["unsafe_gettpl!"] = @benchmark unsafe_gettpl!($res, $tpl, PyObject, $idx) + println("unsafe_gettpl!:\n"); display(results["unsafe_gettpl!"]) + println("--------------------------------------------------") + end + end +end + +println("") +println("Mean times") +println("----------") +foreach((r)->println(rpad(r[1],20), ": ", mean(r[2])), results) diff --git a/src/PyCall.jl b/src/PyCall.jl index 37cd10c2..5225b345 100644 --- a/src/PyCall.jl +++ b/src/PyCall.jl @@ -10,11 +10,12 @@ export pycall, pyimport, pyimport_e, pybuiltin, PyObject, PyReverseDims, pyisinstance, pywrap, pytypeof, pyeval, PyVector, pystring, pystr, pyrepr, pyraise, pytype_mapping, pygui, pygui_start, pygui_stop, pygui_stop_all, @pylab, set!, PyTextIO, @pysym, PyNULL, ispynull, @pydef, - pyimport_conda, @py_str, @pywith, @pycall, pybytes, pyfunction, pyfunctionret + pyimport_conda, @py_str, @pywith, @pycall, pybytes, pyfunction, pyfunctionret, + unsafe_gettpl! import Base: size, ndims, similar, copy, getindex, setindex!, stride, convert, pointer, summary, convert, show, haskey, keys, values, - eltype, get, delete!, empty!, length, isempty, start, done, + eltype, get, get!, delete!, empty!, length, isempty, start, done, next, filter!, hash, splice!, pop!, ==, isequal, push!, append!, insert!, prepend!, unsafe_convert import Compat: pushfirst!, popfirst!, firstindex, lastindex @@ -97,8 +98,12 @@ it is equivalent to a `PyNULL()` object. """ ispynull(o::PyObject) = o.o == PyPtr_NULL +function pydecref_(o::PyPtr) + ccall(@pysym(:Py_DecRef), Cvoid, (PyPtr,), o) +end + function pydecref(o::PyObject) - ccall(@pysym(:Py_DecRef), Cvoid, (PyPtr,), o.o) + pydecref_(o.o) o.o = PyPtr_NULL o end @@ -130,10 +135,15 @@ function pystealref!(o::PyObject) return optr end -function Base.copy!(dest::PyObject, src::PyObject) - pydecref(dest) - dest.o = src.o - return pyincref(dest) +Base.copy!(dest::PyObject, src::PyObject) = Base.copy!(dest, src.o) + +function Base.copy!(dest::PyObject, src::PyPtr) + if dest.o != src + pyincref_(src) + pydecref_(dest.o) + dest.o = src + end + return dest end pyisinstance(o::PyObject, t::PyObject) = @@ -170,6 +180,7 @@ include("pyiterator.jl") include("pyclass.jl") include("callback.jl") include("io.jl") +include("get.jl") ######################################################################### @@ -753,27 +764,6 @@ macro pycall(ex) :(pycall($(map(esc,[kwargs; func; T; args])...))) end -######################################################################### -# Once Julia lets us overload ".", we will use [] to access items, but -# for now we can define "get". - -function get(o::PyObject, returntype::TypeTuple, k, default) - r = ccall((@pysym :PyObject_GetItem), PyPtr, (PyPtr,PyPtr), o,PyObject(k)) - if r == C_NULL - pyerr_clear() - default - else - convert(returntype, PyObject(r)) - end -end - -get(o::PyObject, returntype::TypeTuple, k) = - convert(returntype, PyObject(@pycheckn ccall((@pysym :PyObject_GetItem), - PyPtr, (PyPtr,PyPtr), o, PyObject(k)))) - -get(o::PyObject, k, default) = get(o, PyAny, k, default) -get(o::PyObject, k) = get(o, PyAny, k) - function delete!(o::PyObject, k) e = ccall((@pysym :PyObject_DelItem), Cint, (PyPtr, PyPtr), o, PyObject(k)) if e == -1 diff --git a/src/get.jl b/src/get.jl new file mode 100644 index 00000000..0370faa0 --- /dev/null +++ b/src/get.jl @@ -0,0 +1,95 @@ +######################################################################### +# Once Julia lets us overload ".", we will use [] to access items, but +# for now we can define "get". + +############################### +# get with k<:Any and a default +############################### +function get!(ret::PyObject, o::PyObject, returntype::TypeTuple, k, default) + r = ccall((@pysym :PyObject_GetItem), PyPtr, (PyPtr,PyPtr), o, PyObject(k)) + if r == C_NULL + pyerr_clear() + default + else + if r != ret.o + pydecref_(ret.o) + ret.o = r + end + convert(returntype, ret) + end +end + +get(o::PyObject, returntype::TypeTuple, k, default) = + get!(PyNULL(), o, returntype, k, default) + +# returntype defaults to PyAny +get!(ret::PyObject, o::PyObject, k, default) = get!(ret, o, PyAny, k, default) +get(o::PyObject, k, default) = get(o, PyAny, k, default) + +############################### +# get with k<:Any +############################### +function get!(ret::PyObject, o::PyObject, returntype::TypeTuple, k) + r = @pycheckn ccall((@pysym :PyObject_GetItem), + PyPtr, (PyPtr,PyPtr), o, PyObject(k)) + if r != ret.o + pydecref_(ret.o) + ret.o = r + end + return convert(returntype, ret) +end + +get(o::PyObject, returntype::TypeTuple, k) = get!(PyNULL(), o, returntype, k) + +# returntype defaults to PyAny +get!(ret::PyObject, o::PyObject, k) = get!(ret, o, PyAny, k) +get(o::PyObject, k) = get(o, PyAny, k) + +############################### +# get with k<:Integer +############################### +function get!(ret::PyObject, o::PyObject, returntype::TypeTuple, k::Integer) + if pyisinstance(o, @pyglobalobj :PyTuple_Type) + copy!(ret, @pycheckn ccall(@pysym(:PyTuple_GetItem), PyPtr, (PyPtr, Cint), o, k)) + elseif pyisinstance(o, @pyglobalobj :PyList_Type) + copy!(ret, @pycheckn ccall(@pysym( :PyList_GetItem), PyPtr, (PyPtr, Cint), o, k)) + else + return get!(ret, o, returntype, PyObject(k)) + end + return convert(returntype, ret) +end + +get(o::PyObject, returntype::TypeTuple, k::Integer) = + get!(PyNULL(), o, returntype, k) + +# default to PyObject(k) methods for no returntype, and default variants +get!(ret::PyObject, o::PyObject, returntype::TypeTuple, k::Integer, default) = + get!(ret, o, returntype, PyObject(k), default) + +get!(ret::PyObject, o::PyObject, k::Integer) = get!(ret, o, PyObject(k)) + +get!(ret::PyObject, o::PyObject, k::Integer, default) = + get!(ret, o, PyObject(k), default) + +############################### +# unsafe_gettpl! +############################### + +# struct PyTuple_struct +# refs: +# https://github.com/python/cpython/blob/da1734c58d2f97387ccc9676074717d38b044128/Include/object.h#L106-L115 +# https://github.com/python/cpython/blob/da1734c58d2f97387ccc9676074717d38b044128/Include/tupleobject.h#L25-L33 +struct PyVar_struct + ob_refcnt::Int + ob_type::Ptr{Cvoid} + ob_size::Int + # ob_item::Ptr{PyPtr} +end + +function unsafe_gettpl!(ret::PyObject, o::PyObject, returntype::TypeTuple, k::Int) + pytype_ptr = unsafe_load(o.o).ob_type + # get address of ob_item (just after the end of the struct) + itemsptr = Base.reinterpret(Ptr{PyPtr}, o.o + sizeof(PyVar_struct)) + copy!(ret, unsafe_load(itemsptr, k+1)) # unsafe_load is 1-based + return convert(returntype, ret) +end diff --git a/test/runtests.jl b/test/runtests.jl index 3db84b83..cc56311e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -284,9 +284,8 @@ pymodule_exists(s::AbstractString) = !ispynull(pyimport_e(s)) @test o3[2,3,4] == 9 end end - - # list operations: - let o = PyObject(Any[8,3]) + @testset "list operations" begin + lst = Any[8,3]; o = PyObject(lst) @test collect(push!(o, 5)) == [8,3,5] @test pop!(o) == 5 && collect(o) == [8,3] @test popfirst!(o) == 8 && collect(o) == [3] @@ -294,6 +293,19 @@ pymodule_exists(s::AbstractString) = !ispynull(pyimport_e(s)) @test collect(prepend!(o, [5,4,2])) == [5,4,2,9,3] @test collect(append!(o, [1,6,8])) == [5,4,2,9,3,1,6,8] @test isempty(empty!(o)) + + # get! with preallocated return object + res = PyNULL() + lst = Any[8,3,5]; o = PyObject(lst) # re-init since emptied + lpyo = map(PyObject, lst) + for i in eachindex(lst) + @test get!(res, o, PyObject, i-1) == lpyo[i] + @test get!(res, o, i-1) == lst[i] # PyAny default + end + + # get! with default if key not found + @test get!(res, o, PyObject, 3, 12) == 12 + @test get!(res, o, 3, 12) == 12 end let o = PyObject(Any[8,3]) @test collect(append!(o, o)) == [8,3,8,3] @@ -301,6 +313,40 @@ pymodule_exists(s::AbstractString) = !ispynull(pyimport_e(s)) @test collect(prepend!(o, o)) == [8,3,8,3,1,8,3,8,3,1] end + @testset "tuple get" begin + tpl = ("a","b","c"); o = PyObject(tpl); tpl_os = map(PyObject, tpl) + + # get with PyObject key + @test ((get(o, PyObject, PyObject(i-1)) for i in eachindex(tpl))...,) == tpl_os + # get with PyObject key and default + @test get(o, PyObject, PyObject(3), "z") == PyObject("z") + + # get with Integer key + @test ((get(o, PyObject, PyObject(i-1)) for i in eachindex(tpl))...,) == tpl_os + # get with Integer key and default + @test get(o, PyObject, 3, "z") == PyObject("z") + + # get! + res = PyNULL() + for i in eachindex(tpl) + @test get!(res, o, PyObject, i-1) == tpl_os[i] + @test get!(res, o, PyObject, PyObject(i-1)) == tpl_os[i] + @test get!(res, o, i-1) == tpl[i] + @test get!(res, o, PyObject(i-1)) == tpl[i] + end + + # get! with default if key not found + @test get!(res, o, PyObject, 3, "z") == PyObject("z") + @test get!(res, o, PyObject, PyObject(3), "z") == PyObject("z") + @test get!(res, o, 3, "z") == PyObject("z") + @test get!(res, o, PyObject(3), "z") == PyObject("z") + + # fast unsafe_gettpl! + for i in eachindex(tpl) + @test unsafe_gettpl!(res, o, PyObject, i-1) == tpl_os[i] + end + end + # issue #216: @test length(collect(pyimport("itertools")[:combinations]([1,2,3],2))) == 3