Skip to content
Permalink
Browse files
Introduce new maxdepth keyword argument for setting a limit on nesting (#147)

level limit

Alternative fix for #143. This is a more general fix than just
specializing CategoricalArrays. This should prevent more general cases
of the same issue: i.e. someone accidentally passes a recursive data
structure and `Arrow.write` gets stuck trying to recursively serialize.
  • Loading branch information
quinnj committed Mar 10, 2021
1 parent 8e7869d commit 0f1b3500b78edc48861ce4782aaf994d355cbf44
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
@@ -48,7 +48,10 @@ function toarrowvector(x, i=1, de=Dict{Int64, Any}(), ded=DictEncoding[], meta=g
return A
end

function arrowvector(x, i, nl, fi, de, ded, meta; dictencoding::Bool=false, dictencode::Bool=false, kw...)
function arrowvector(x, i, nl, fi, de, ded, meta; dictencoding::Bool=false, dictencode::Bool=false, maxdepth::Int=DEFAULT_MAX_DEPTH, kw...)
if nl > maxdepth
error("reached nested serialization level ($nl) deeper than provided max depth argument ($(maxdepth)); to increase allowed nesting level, pass `maxdepth=X`")
end
if !(x isa DictEncode) && !dictencoding && (dictencode || (x isa AbstractArray && DataAPI.refarray(x) !== x))
x = DictEncode(x, dictencodeid(i, nl, fi))
elseif x isa DictEncoded
@@ -39,6 +39,8 @@ to table, column, and other objects).
"""
getmetadata(x, default=nothing) = get(OBJ_METADATA, x, default)

# Default value of the `maxdepth` keyword argument: the deepest nested
# serialization level allowed before `arrowvector` raises an error.
const DEFAULT_MAX_DEPTH = 6

"""
Arrow.write(io::IO, tbl)
Arrow.write(file::String, tbl)
@@ -66,24 +68,25 @@ Supported keyword arguments to `Arrow.write` include:
* `dictencodenested::Bool=false`: whether nested data type columns should also dict encode nested arrays/buffers; other language implementations [may not support this](https://arrow.apache.org/docs/status.html)
* `denseunions::Bool=true`: whether Julia `Vector{<:Union}` arrays should be written using the dense union layout; passing `false` will result in the sparse union layout
* `largelists::Bool=false`: causes list column types to be written with Int64 offset arrays; mainly for testing purposes; by default, Int64 offsets will be used only if needed
* `maxdepth::Int=$DEFAULT_MAX_DEPTH`: deepest allowed nested serialization level; this is provided by default to prevent accidental infinite recursion with mutually recursive data structures
* `file::Bool=false`: if an `io` argument is being written to, passing `file=true` will cause the arrow file format to be written instead of just IPC streaming
"""
function write end

# Partially applied form: capture the destination and any keyword options now,
# and return a one-argument closure that forwards its argument to the
# two-argument `write` method together with those same keywords.
function write(io_or_file; kw...)
    return tbl -> write(io_or_file, tbl; kw...)
end

function write(file::String, tbl; largelists::Bool=false, compress::Union{Nothing, Symbol, LZ4FrameCompressor, ZstdCompressor}=nothing, denseunions::Bool=true, dictencode::Bool=false, dictencodenested::Bool=false, alignment::Int=8)
function write(file::String, tbl; largelists::Bool=false, compress::Union{Nothing, Symbol, LZ4FrameCompressor, ZstdCompressor}=nothing, denseunions::Bool=true, dictencode::Bool=false, dictencodenested::Bool=false, alignment::Int=8, maxdepth::Int=DEFAULT_MAX_DEPTH)
open(file, "w") do io
write(io, tbl, true, largelists, compress, denseunions, dictencode, dictencodenested, alignment)
write(io, tbl, true, largelists, compress, denseunions, dictencode, dictencodenested, alignment, maxdepth)
end
return file
end

function write(io::IO, tbl; largelists::Bool=false, compress::Union{Nothing, Symbol, LZ4FrameCompressor, ZstdCompressor}=nothing, denseunions::Bool=true, dictencode::Bool=false, dictencodenested::Bool=false, alignment::Int=8, file::Bool=false)
return write(io, tbl, file, largelists, compress, denseunions, dictencode, dictencodenested, alignment)
function write(io::IO, tbl; largelists::Bool=false, compress::Union{Nothing, Symbol, LZ4FrameCompressor, ZstdCompressor}=nothing, denseunions::Bool=true, dictencode::Bool=false, dictencodenested::Bool=false, alignment::Int=8, maxdepth::Int=DEFAULT_MAX_DEPTH, file::Bool=false)
return write(io, tbl, file, largelists, compress, denseunions, dictencode, dictencodenested, alignment, maxdepth)
end

function write(io, source, writetofile, largelists, compress, denseunions, dictencode, dictencodenested, alignment)
function write(io, source, writetofile, largelists, compress, denseunions, dictencode, dictencodenested, alignment, maxdepth)
if compress === :lz4
compress = LZ4_FRAME_COMPRESSOR
elseif compress === :zstd
@@ -108,7 +111,7 @@ function write(io, source, writetofile, largelists, compress, denseunions, dicte
@sync for (i, tbl) in enumerate(Tables.partitions(source))
@debug 1 "processing table partition i = $i"
if i == 1
cols = toarrowtable(tbl, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested)
cols = toarrowtable(tbl, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth)
sch[] = Tables.schema(cols)
firstcols[] = cols
put!(msgs, makeschemamsg(sch[], cols), i)
@@ -124,7 +127,7 @@ function write(io, source, writetofile, largelists, compress, denseunions, dicte
put!(msgs, makerecordbatchmsg(sch[], cols, alignment), i, true)
else
Threads.@spawn begin
cols = toarrowtable(tbl, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested)
cols = toarrowtable(tbl, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth)
if !isempty(cols.dictencodingdeltas)
for de in cols.dictencodingdeltas
dictsch = Tables.Schema((:col,), (eltype(de.data),))
@@ -188,7 +191,7 @@ struct ToArrowTable
dictencodingdeltas::Vector{DictEncoding}
end

function toarrowtable(x, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested)
function toarrowtable(x, dictencodings, largelists, compress, denseunions, dictencode, dictencodenested, maxdepth)
@debug 1 "converting input table to arrow formatted columns"
cols = Tables.columns(x)
meta = getmetadata(cols)
@@ -199,7 +202,7 @@ function toarrowtable(x, dictencodings, largelists, compress, denseunions, dicte
newtypes = Vector{Type}(undef, N)
dictencodingdeltas = DictEncoding[]
Tables.eachcolumn(sch, cols) do col, i, nm
newcol = toarrowvector(col, i, dictencodings, dictencodingdeltas; compression=compress, largelists=largelists, denseunions=denseunions, dictencode=dictencode, dictencodenested=dictencodenested)
newcol = toarrowvector(col, i, dictencodings, dictencodingdeltas; compression=compress, largelists=largelists, denseunions=denseunions, dictencode=dictencode, dictencodenested=dictencodenested, maxdepth=maxdepth)
newtypes[i] = eltype(newcol)
newcols[i] = newcol
end

0 comments on commit 0f1b350

Please sign in to comment.