Skip to content
Permalink
Browse files
Various fixes and mechanics to test round-tripping w/ pyarrow (#30)
This doesn't actually hook up pyarrow roundtrip testing, but you can run
the pyarrowroundtrip.jl test file if you have python3 and pyarrow
installed locally (along with PyCall.jl on the julia side). It then
tests most of our testtables.jl testing tables by writing them in julia,
passing written bytes to pyarrow, reading them via pyarrow, writing them
back out, then reading in on julia side. The fixes were pretty minor,
but it feels much better knowing all these examples work well (and will be
easy to test in the future).
  • Loading branch information
quinnj committed Oct 3, 2020
1 parent 096d754 commit 3e1c3c9cf064daa137994a0d66f44cdcdfa55bf2
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 11 deletions.
@@ -111,7 +111,7 @@ macro scopedenum(T, syms...)
# enum definition
primitive type $(esc(typename)) <: ScopedEnum{$(basetype)} $(sizeof(basetype) * 8) end
function $(esc(typename))(x::Integer)
$(Base.Enums.membershiptest(:x, values)) || enum_argument_error($(Expr(:quote, typename)), x)
$(Base.Enums.membershiptest(:x, values)) || Base.Enums.enum_argument_error($(Expr(:quote, typename)), x)
return Core.bitcast($(esc(typename)), convert($(basetype), x))
end
if isdefined(Base.Enums, :namemap)
@@ -474,7 +474,7 @@ struct DenseUnion{T, S} <: ArrowVector{T}
end

Base.size(s::DenseUnion) = size(s.typeIds)
nullcount(x::DenseUnion) = nullcount(x.data[1])
nullcount(x::DenseUnion) = 0

arrowvector(U::Union, ::Type{S}, x, de, meta; denseunions::Bool=true, kw...) where {S} =
arrowvector(denseunions ? DenseUnionVector(x) : SparseUnionVector(x), de, meta; denseunions=denseunions, kw...)
@@ -536,7 +536,7 @@ struct SparseUnion{T, S} <: ArrowVector{T}
end

Base.size(s::SparseUnion) = size(s.typeIds)
nullcount(x::SparseUnion) = nullcount(x.data[1])
nullcount(x::SparseUnion) = 0

@propagate_inbounds function Base.getindex(s::SparseUnion{T}, i::Integer) where {T}
@boundscheck checkbounds(s, i)
@@ -195,6 +195,7 @@ Base.convert(::Type{Dates.Time}, x::Time{U, T}) where {U, T} = Dates.Time(Dates.
# Serialize the flatbuffer type metadata for a `Time{U, T}` column into builder `b`.
#
# Writes a Time table carrying the unit `U` and a bit width derived from the
# storage type `T` (8 bits per byte of `sizeof(T)`, passed as an `Int32`).
# Returns a 3-tuple: the `Meta.Time` type tag, the value produced by
# `Meta.timeEnd(b)`, and `nothing` in the third slot.
# NOTE(review): the third slot presumably means "no child fields" — confirm
# against the other `arrowtype` methods.
function arrowtype(b, ::Type{Time{U, T}}) where {U, T}
    bitwidth = Int32(8 * sizeof(T))
    Meta.timeStart(b)
    Meta.timeAddUnit(b, U)
    Meta.timeAddBitWidth(b, bitwidth)
    typeoffset = Meta.timeEnd(b)
    return Meta.Time, typeoffset, nothing
end

@@ -1,4 +1,4 @@
FlatBuffers.@scopedenum MetadataVersion::Int16 V1 V2 V3 V4
FlatBuffers.@scopedenum MetadataVersion::Int16 V1 V2 V3 V4 V5

struct Null <: FlatBuffers.Table
bytes::Vector{UInt8}
@@ -257,7 +257,7 @@ function Base.iterate(x::BatchIterator, (pos, id)=(x.startpos, 0))
msg = FlatBuffers.getrootas(Meta.Message, x.bytes, pos-1)
pos += msglen
# pos now points to message body
@debug 1 "parsing message: msglen = $msglen, version = $(msg.version), bodyLength = $(msg.bodyLength)"
@debug 1 "parsing message: msglen = $msglen, bodyLength = $(msg.bodyLength)"
return Batch(msg, x.bytes, pos, id), (pos + msg.bodyLength, id + 1)
end

@@ -292,7 +292,7 @@ end
function makemessage(b, headerType, header, columns=nothing, bodylen=0)
# write the message flatbuffer object
Meta.messageStart(b)
Meta.messageAddVersion(b, Meta.MetadataVersion.V4)
Meta.messageAddVersion(b, Meta.MetadataVersion.V5)
Meta.messageAddHeaderType(b, headerType)
Meta.messageAddHeader(b, header)
Meta.messageAddBodyLength(b, Int64(bodylen))
@@ -0,0 +1,43 @@
# Manual pyarrow round-trip script: each Julia test table is written with
# Arrow.write, passed through pyarrow (read + re-written), and read back in
# with Arrow.Table. Requires python3 with pyarrow installed and PyCall.jl.
# NOTE(review): `Arrow` is referenced below but not imported in this file;
# presumably the caller has already loaded it — confirm.

# Set the PYTHON environment variable before PyCall is imported.
ENV["PYTHON"] = "python3"
import PyCall
# Handle to the Python `pyarrow` module, used for all Python-side calls below.
pa = PyCall.pyimport("pyarrow")
# Pull in the shared `testtables` definitions from the Arrow package's test directory.
include(joinpath(dirname(pathof(Arrow)), "../test/testtables.jl"))

# Each entry is (name, table, write kwargs, read kwargs, extra tests);
# `readkw` and `extratests` are destructured but not used in this loop.
for (nm, t, writekw, readkw, extratests) in testtables
# The "unions" table is excluded from the round trip.
nm == "unions" && continue
println("pyarrow roundtrip: $nm")
# Julia side: serialize the table into an in-memory buffer.
io = IOBuffer()
Arrow.write(io, t; writekw...)
seekstart(io)
# Hand the written bytes to Python and open them as an IPC stream.
buf = PyCall.pybytes(take!(io))
reader = pa.ipc.open_stream(buf)
# Python side: re-write every record batch into a fresh stream using the reader's schema.
sink = pa.BufferOutputStream()
writer = pa.ipc.new_stream(sink, reader.schema)
for batch in reader
writer.write_batch(batch)
end
writer.close()
# Bring the pyarrow-written output back to Julia as a UInt8 copy and parse it.
buf = sink.getvalue()
jbytes = copy(reinterpret(UInt8, buf))
# No assertions are made on `tt`; the check here is only that reading does not throw.
tt = Arrow.Table(jbytes)
end

# Build a dense-union column on the pyarrow side, write it as an IPC stream,
# and read the resulting bytes with Arrow.Table on the Julia side.
# Relies on the `pa` (pyarrow module) binding and PyCall being set up earlier
# in this script.

# Field/schema description of the union column. `sch` is not used when
# writing below (the stream is created from `batch.schema`).
f1 = pa.field("f1", pa.float64(), true)
f2 = pa.field("f2", pa.int64(), false)
fu = pa.field("col1", pa.union([f1, f2], "dense"))
sch = pa.schema([fu])

# Child arrays: child 0 is float64 with a trailing null, child 1 is int64.
xs = pa.array([2.0, 4.0, PyCall.pynothing[]], type=pa.float64())
ys = pa.array([1, 3], type=pa.int64())
# For a dense union, types[i] selects the child and offsets[i] indexes into
# that child. The final slot uses offset 2, which only exists in `xs`
# (length 3, holding the null); `ys` has length 2, so the last type id must
# be 0 rather than 1. Logical values: 2.0, 1, 4.0, 3, null.
types = pa.array([0, 1, 0, 1, 0], type=pa.int8())
offsets = pa.array([0, 0, 1, 1, 2], type=pa.int32())
union_arr = pa.UnionArray.from_dense(types, offsets, [xs, ys])
data = [union_arr]
batch = pa.record_batch(data, names=["col1"])
# Write the single record batch to an in-memory IPC stream.
sink = pa.BufferOutputStream()
writer = pa.ipc.new_stream(sink, batch.schema)
writer.write_batch(batch)
writer.close()
# Copy the pyarrow output into Julia bytes and parse it; no assertions are
# made on `tt` here.
buf = sink.getvalue()
jbytes = copy(reinterpret(UInt8, buf))
tt = Arrow.Table(jbytes)
@@ -65,8 +65,8 @@ testtables = [
(
"unions",
(
col1=Arrow.DenseUnionVector([1, 2.0, 3, 4.0, missing]),
col2=Arrow.SparseUnionVector([1, 2.0, 3, 4.0, missing]),
col1=Arrow.DenseUnionVector( Union{Int64, Float64, Missing}[1, 2.0, 3, 4.0, missing]),
col2=Arrow.SparseUnionVector(Union{Int64, Float64, Missing}[1, 2.0, 3, 4.0, missing]),
),
NamedTuple(),
NamedTuple(),
@@ -150,11 +150,11 @@ testtables = [
(
"dictencode keyword",
(
col1=Int64[1,2,3,4,5,6,7,8,9,10],
col1=Int64[1,2,3,4],
col2=Union{String, Missing}["hey", "there", "sailor", missing],
col3=Arrow.DictEncode(NamedTuple{(:a, :b), Tuple{Int64, Union{String, Missing}}}[(a=Int64(1), b=missing), (a=Int64(1), b=missing), (a=Int64(3), b="sailor"), (a=Int64(4), b="jo-bob")]),
col4=[:a, :b, :c, :d, :a, :b, :c, :d, :e, missing],
col5=[Date(2020, 1, 1) for x = 1:10]
col4=[:a, :b, :c, missing],
col5=[Date(2020, 1, 1) for x = 1:4]
),
(dictencode=true,),
NamedTuple(),

0 comments on commit 3e1c3c9

Please sign in to comment.