Skip to content
Permalink
Browse files
Refactor nested dict encoding and isdelta dictionary batch support (#43)
Fixes #32 among other issues. Turns out the probably-rare-in-practice
data race mentioned in #32 was the least of the worries. While digging
into things, I realized we weren't doing isDelta dictionary batches
right at all. In particular, we were basically writing each record
batch/dictionary batch independent of each other, but using the same
dictionary batch ids. We didn't have test failures because we weren't
testing isDelta batches anyway :P.

In this PR, everything is cleaned up quite a bit. We now generate a
dictionary encoding id based on the column index, nesting level, and
field index (in the case of structs and unions). This allows us to
re-use the same constructed dict encodings from batch to batch, and also
allows us to support the use-case of re-using a single dict encoding
across multiple columns if so desired (user would just pass in their own
`DictEncode` column pointing to the same id). We also avoid race
conditions by putting a lock around dict encodings so different threads
writing will have to take turns.
  • Loading branch information
quinnj committed Oct 22, 2020
1 parent 5fd86f6 commit 799ef104a078aed924e3f8b09bc9d4c28c7a9e98
Show file tree
Hide file tree
Showing 16 changed files with 193 additions and 174 deletions.
@@ -1,6 +1,6 @@
name = "Arrow"
uuid = "69666777-d1a9-59fb-9406-91d4454c9d45"
authors = ["quinnj <quinn.jacobd@gmail.com>", "ExpandingMan <expandingman@protonmail.com"]
authors = ["quinnj <quinn.jacobd@gmail.com>"]
version = "0.3.0"

[deps]
@@ -22,10 +22,10 @@ validitybitmap(x::ArrowVector) = x.validity
nullcount(x::ArrowVector) = validitybitmap(x).nc
getmetadata(x::ArrowVector) = x.metadata

function toarrowvector(x, de=DictEncoding[], meta=getmetadata(x); compression::Union{Nothing, LZ4FrameCompressor, ZstdCompressor}=nothing, kw...)
function toarrowvector(x, i=1, de=Dict{Int64, Any}(), ded=DictEncoding[], meta=getmetadata(x); compression::Union{Nothing, LZ4FrameCompressor, ZstdCompressor}=nothing, kw...)
@debug 2 "converting top-level column to arrow format: col = $(typeof(x)), compression = $compression, kw = $(kw.data)"
@debug 3 x
A = arrowvector(x, de, meta; compression=compression, kw...)
A = arrowvector(x, i, 0, 0, de, ded, meta; compression=compression, kw...)
if compression isa LZ4FrameCompressor
A = compress(Meta.CompressionType.LZ4_FRAME, compression, A)
elseif compression isa ZstdCompressor
@@ -36,36 +36,35 @@ function toarrowvector(x, de=DictEncoding[], meta=getmetadata(x); compression::U
return A
end

function arrowvector(x, de, meta; dictencoding::Bool=false, dictencode::Bool=false, kw...)
function arrowvector(x, i, nl, fi, de, ded, meta; dictencoding::Bool=false, dictencode::Bool=false, kw...)
if !(x isa DictEncode) && !dictencoding && (dictencode || (x isa AbstractArray && DataAPI.refarray(x) !== x))
x = DictEncode(x)
x = DictEncode(x, dictencodeid(i, nl, fi))
end
T = eltype(x)
S = maybemissing(T)
return arrowvector(S, T, x, de, meta; kw...)
S = maybemissing(eltype(x))
return arrowvector(S, x, i, nl, fi, de, ded, meta; dictencode=dictencode, kw...)
end

# conversions to arrow types
arrowvector(::Type{Dates.Date}, ::Type{S}, x, de, meta; kw...) where {S} =
arrowvector(converter(DATE, x), de, meta; kw...)
arrowvector(::Type{Dates.Time}, ::Type{S}, x, de, meta; kw...) where {S} =
arrowvector(converter(TIME, x), de, meta; kw...)
arrowvector(::Type{Dates.DateTime}, ::Type{S}, x, de, meta; kw...) where {S} =
arrowvector(converter(DATETIME, x), de, meta; kw...)
arrowvector(::Type{P}, ::Type{S}, x, de, meta; kw...) where {P <: Dates.Period, S} =
arrowvector(converter(Duration{arrowperiodtype(P)}, x), de, meta; kw...)
arrowvector(::Type{Dates.Date}, x, i, nl, fi, de, ded, meta; kw...) =
arrowvector(converter(DATE, x), i, nl, fi, de, ded, meta; kw...)
arrowvector(::Type{Dates.Time}, x, i, nl, fi, de, ded, meta; kw...) =
arrowvector(converter(TIME, x), i, nl, fi, de, ded, meta; kw...)
arrowvector(::Type{Dates.DateTime}, x, i, nl, fi, de, ded, meta; kw...) =
arrowvector(converter(DATETIME, x), i, nl, fi, de, ded, meta; kw...)
arrowvector(::Type{P}, x, i, nl, fi, de, ded, meta; kw...) where {P <: Dates.Period} =
arrowvector(converter(Duration{arrowperiodtype(P)}, x), i, nl, fi, de, ded, meta; kw...)

# fallback that calls ArrowType
function arrowvector(::Type{S}, ::Type{T}, x, de, meta; kw...) where {S, T}
function arrowvector(::Type{S}, x, i, nl, fi, de, ded, meta; kw...) where {S}
if ArrowTypes.istyperegistered(S)
meta = meta === nothing ? Dict{String, String}() : meta
arrowtype = ArrowTypes.getarrowtype!(meta, S)
return arrowvector(converter(arrowtype, x), de, meta; kw...)
return arrowvector(converter(arrowtype, x), i, nl, fi, de, ded, meta; kw...)
end
return arrowvector(ArrowType(S), x, de, meta; kw...)
return arrowvector(ArrowType(S), x, i, nl, fi, de, ded, meta; kw...)
end

arrowvector(::NullType, x, de, meta; kw...) = MissingVector(length(x))
arrowvector(::NullType, x, i, nl, fi, de, ded, meta; kw...) = MissingVector(length(x))
compress(Z::Meta.CompressionType, comp, v::MissingVector) =
Compressed{Z, MissingVector}(v, CompressedBuffer[], length(v), length(v), Compressed[])

@@ -44,7 +44,7 @@ end
return v
end

function arrowvector(::BoolType, x, de, meta; kw...)
function arrowvector(::BoolType, x, i, nl, fi, de, ded, meta; kw...)
validity = ValidityBitmap(x)
len = length(x)
blen = cld(len, 8)
@@ -33,10 +33,11 @@ struct DictEncodeType{T} end
getT(::Type{DictEncodeType{T}}) where {T} = T

struct DictEncode{T, A} <: AbstractVector{DictEncodeType{T}}
id::Int64
data::A
end

DictEncode(x::A) where {A} = DictEncode{eltype(A), A}(x)
DictEncode(x::A, id=-1) where {A} = DictEncode{eltype(A), A}(id, x)
Base.IndexStyle(::Type{<:DictEncode}) = Base.IndexLinear()
Base.size(x::DictEncode) = (length(x.data),)
Base.iterate(x::DictEncode, st...) = iterate(x.data, st...)
@@ -69,36 +70,70 @@ indtype(d::D) where {D <: DictEncoded} = indtype(D)
indtype(::Type{DictEncoded{T, S, A}}) where {T, S, A} = signedtype(S)
indtype(c::Compressed{Z, A}) where {Z, A <: DictEncoded} = indtype(A)

dictencodeid(colidx, nestedlevel, fieldid) = (Int64(nestedlevel) << 48) | (Int64(fieldid) << 32) | Int64(colidx)

getid(d::DictEncoded) = d.encoding.id
getid(c::Compressed{Z, A}) where {Z, A <: DictEncoded} = c.data.encoding.id

function arrowvector(::DictEncodedType, x, de, meta; dictencode::Bool=false, dictencodenested::Bool=false, kw...)
function arrowvector(::DictEncodedType, x, i, nl, fi, de, ded, meta; dictencode::Bool=false, dictencodenested::Bool=false, kw...)
@assert x isa DictEncode
id = x.id == -1 ? dictencodeid(i, nl, fi) : x.id
x = x.data
len = length(x)
validity = ValidityBitmap(x)
if x isa AbstractArray && DataAPI.refarray(x) !== x
inds = copy(DataAPI.refarray(x))
if !haskey(de, id)
# dict encoding doesn't exist yet, so create for 1st time
if DataAPI.refarray(x) === x
# need to encode ourselves
x = PooledArray(x)
inds = DataAPI.refarray(x)
else
inds = copy(DataAPI.refarray(x))
end
# adjust to "offset" instead of index
for i = 1:length(inds)
@inbounds inds[i] -= 1
end
pool = DataAPI.refpool(x)
data = arrowvector(pool, de, nothing; dictencode=dictencodenested, dictencodenested=dictencodenested, dictencoding=true, kw...)
encoding = DictEncoding{eltype(data), typeof(data)}(0, data, false)
data = arrowvector(DataAPI.refpool(x), i, nl, fi, de, ded, nothing; dictencode=dictencodenested, dictencodenested=dictencodenested, dictencoding=true, kw...)
encoding = DictEncoding{eltype(data), typeof(data)}(id, data, false)
de[id] = Lockable(encoding)
else
# need to encode ourselves
y = PooledArray(x)
inds = DataAPI.refarray(y)
# adjust to "offset" instead of index
for i = 1:length(inds)
@inbounds inds[i] = inds[i] - 1
# encoding already exists
# compute inds based on it
# if value doesn't exist in encoding, push! it
# also add to deltas updates
encodinglockable = de[id]
@lock encodinglockable begin
encoding = encodinglockable.x
pool = Dict(a => (b - 1) for (b, a) in enumerate(encoding))
deltas = eltype(x)[]
len = length(x)
inds = Vector{encodingtype(len)}(undef, len)
for (j, val) in enumerate(x)
@inbounds inds[j] = get!(pool, val) do
push!(deltas, val)
length(pool)
end
end
if !isempty(deltas)
data = arrowvector(deltas, i, nl, fi, de, ded, nothing; dictencode=dictencodenested, dictencodenested=dictencodenested, dictencoding=true, kw...)
push!(ded, DictEncoding{eltype(data), typeof(data)}(id, data, false))
if typeof(encoding.data) <: ChainedVector
append!(encoding.data, data)
else
data2 = ChainedVector([encoding.data, data])
encoding = DictEncoding{eltype(data2), typeof(data2)}(id, data2, false)
de[id] = Lockable(encoding)
end
end
end
data = arrowvector(DataAPI.refpool(y), de, nothing; dictencode=dictencodenested, dictencodenested=dictencodenested, dictencoding=true, kw...)
encoding = DictEncoding{eltype(data), typeof(data)}(0, data, false)
end
push!(de, encoding)
return DictEncoded(UInt8[], validity, inds, encoding, data.metadata)
if meta !== nothing && data.metadata !== nothing
merge!(meta, data.metadata)
elseif data.metadata !== nothing
meta = data.metadata
end
return DictEncoded(UInt8[], validity, inds, encoding, meta)
end

@propagate_inbounds function Base.getindex(d::DictEncoded, i::Integer)
@@ -83,14 +83,14 @@ end
return x, (i + 1, chunk, chunk_i, len)
end

function arrowvector(::FixedSizeListType, x, de, meta; kw...)
function arrowvector(::FixedSizeListType, x, i, nl, fi, de, ded, meta; kw...)
len = length(x)
validity = ValidityBitmap(x)
flat = ToFixedSizeList(x)
if eltype(flat) == UInt8
data = flat
else
data = arrowvector(flat, de, nothing; kw...)
data = arrowvector(flat, i, nl + 1, fi, de, ded, nothing; kw...)
end
return FixedSizeList{eltype(x), typeof(data)}(UInt8[], validity, data, len, meta)
end
@@ -173,15 +173,15 @@ end
return x, (i, chunk, chunk_i, chunk_len, len)
end

# Convert a list-typed column `x` to an arrow `List` vector.
#
# `i`, `nl` and `fi` (column index, nesting level, field index) together with
# `de`/`ded` are passed through to the nested `arrowvector` call; the child data
# is converted one nesting level deeper (`nl + 1`). `largelists` is forwarded
# both to `ToList` and to the nested conversion.
function arrowvector(::ListType, x, i, nl, fi, de, ded, meta; largelists::Bool=false, kw...)
    len = length(x)
    validity = ValidityBitmap(x)
    flat = ToList(x; largelists=largelists)
    offsets = Offsets(UInt8[], flat.inds)
    if eltype(flat) == UInt8 # binary or utf8string
        # flattened bytes are used directly as the child data
        data = flat
    else
        # the keyword was previously misspelled `lareglists`, so the caller's
        # `largelists` choice was silently dropped for nested children
        data = arrowvector(flat, i, nl + 1, fi, de, ded, nothing; largelists=largelists, kw...)
    end
    return List{eltype(x), eltype(flat.inds), typeof(data)}(UInt8[], validity, offsets, data, len, meta)
end
@@ -37,7 +37,7 @@ end
keyvalues(KT, ::Missing) = missing
keyvalues(KT, x::AbstractDict) = [KT(k, v) for (k, v) in pairs(x)]

function arrowvector(::MapType, x, de, meta; largelists::Bool=false, kw...)
function arrowvector(::MapType, x, i, nl, fi, de, ded, meta; largelists::Bool=false, kw...)
len = length(x)
validity = ValidityBitmap(x)
ET = eltype(x)
@@ -47,7 +47,7 @@ function arrowvector(::MapType, x, de, meta; largelists::Bool=false, kw...)
T = DT !== ET ? Union{Missing, VT} : VT
flat = ToList(T[keyvalues(KT, y) for y in x]; largelists=largelists)
offsets = Offsets(UInt8[], flat.inds)
data = arrowvector(flat, de, nothing; lareglists=largelists, kw...)
data = arrowvector(flat, i, nl + 1, fi, de, ded, nothing; lareglists=largelists, kw...)
return Map{ET, eltype(flat.inds), typeof(data)}(validity, offsets, data, len, meta)
end

@@ -58,7 +58,7 @@ end
return v
end

function arrowvector(::PrimitiveType, x, de, meta; kw...)
function arrowvector(::PrimitiveType, x, i, nl, fi, de, ded, meta; kw...)
validity = ValidityBitmap(x)
return Primitive(eltype(x), UInt8[], validity, x, length(x), meta)
end
@@ -67,14 +67,14 @@ Base.size(x::ToStruct) = (length(x.data),)
Base.@propagate_inbounds function Base.getindex(A::ToStruct{T, j}, i::Integer) where {T, j}
@boundscheck checkbounds(A, i)
@inbounds x = A.data[i]
return @miss_or(x, @inbounds getfield(x, j))
return x === missing ? ArrowTypes.default(T) : getfield(x, j)
end

function arrowvector(::StructType, x, de, meta; kw...)
function arrowvector(::StructType, x, i, nl, fi, de, ded, meta; kw...)
len = length(x)
validity = ValidityBitmap(x)
T = Base.nonmissingtype(eltype(x))
data = Tuple(arrowvector(ToStruct(x, i), de, nothing; kw...) for i = 1:fieldcount(T))
data = Tuple(arrowvector(ToStruct(x, j), i, nl + 1, j, de, ded, nothing; kw...) for j = 1:fieldcount(T))
return Struct{eltype(x), typeof(data)}(validity, data, len, meta)
end

@@ -151,20 +151,20 @@ Base.@propagate_inbounds function Base.getindex(A::ToSparseUnion{T}, i::Integer)
return @inbounds x isa T ? x : ArrowTypes.default(T)
end

arrowvector(U::Union, ::Type{S}, x, de, meta; denseunions::Bool=true, kw...) where {S} =
arrowvector(denseunions ? DenseUnionVector(x) : SparseUnionVector(x), de, meta; denseunions=denseunions, kw...)
arrowvector(U::Union, x, i, nl, fi, de, ded, meta; denseunions::Bool=true, kw...) =
arrowvector(denseunions ? DenseUnionVector(x) : SparseUnionVector(x), i, nl, fi, de, ded, meta; denseunions=denseunions, kw...)

function arrowvector(::UnionType, x, de, meta; kw...)
function arrowvector(::UnionType, x, i, nl, fi, de, ded, meta; kw...)
UT = eltype(x)
if unionmode(UT) == Meta.UnionMode.Dense
x = x isa DenseUnionVector ? x.itr : x
typeids, offsets, data = todense(UT, x)
data2 = map(y -> arrowvector(y, de, nothing; kw...), data)
data2 = map(y -> arrowvector(y[2], i, nl + 1, y[1], de, ded, nothing; kw...), enumerate(data))
return DenseUnion{UT, typeof(data2)}(UInt8[], UInt8[], typeids, offsets, data2, meta)
else
x = x isa SparseUnionVector ? x.itr : x
typeids = sparsetypeids(UT, x)
data3 = Tuple(arrowvector(ToSparseUnion(fieldtype(eltype(UT), i), x), de, nothing; kw...) for i = 1:fieldcount(eltype(UT)))
data3 = Tuple(arrowvector(ToSparseUnion(fieldtype(eltype(UT), j), x), i, nl + 1, j, de, ded, nothing; kw...) for j = 1:fieldcount(eltype(UT)))
return SparseUnion{UT, typeof(data3)}(UInt8[], typeids, data3, meta)
end
end
@@ -95,6 +95,7 @@ default(T) = zero(T)
default(::Type{Symbol}) = Symbol()
default(::Type{Char}) = '\0'
default(::Type{String}) = ""
default(::Type{Union{T, Missing}}) where {T} = default(T)

function default(::Type{A}) where {A <: AbstractVector{T}} where {T}
a = similar(A, 1)
@@ -127,8 +127,8 @@ function Base.getproperty(x::DictionaryBatch, field::Symbol)
return FlatBuffers.init(RecordBatch, FlatBuffers.bytes(x), y)
end
elseif field === :isDelta
o = FlatBuffers.offset(x, 4)
o != 0 && return FlatBuffers.get(x, o + FlatBuffers.pos(x), Bool)
o = FlatBuffers.offset(x, 8)
o != 0 && return FlatBuffers.get(x, o + FlatBuffers.pos(x), Base.Bool)
return false
end
return nothing
@@ -156,3 +156,64 @@ function readmessage(filebytes, off=9)

FlatBuffers.getrootas(Meta.Message, filebytes, off + 8)
end

# a custom Channel type that only allows put!-ing objects in a specific, monotonically increasing order
#
# Fields:
#   chan - the wrapped `Channel{T}` that actually stores the objects
#   cond - condition used by the ordered `put!`/`close` methods to wait and notify
#   i    - mutable counter (starts at 1) compared against the index given to `put!`;
#          typed as the concrete `Base.RefValue{Int}` (what `Ref(1)` returns) rather
#          than the abstract `Ref{Int}` so that `ch.i[]` is statically dispatched
struct OrderedChannel{T}
    chan::Channel{T}
    cond::Threads.Condition
    i::Base.RefValue{Int}
end

# Create an ordered channel wrapping a `Channel{T}(sz)`, with the counter at 1.
OrderedChannel{T}(sz) where {T} = OrderedChannel{T}(Channel{T}(sz), Threads.Condition(), Ref(1))

# Iteration is delegated to the wrapped channel.
Base.iterate(ch::OrderedChannel, st...) = iterate(ch.chan, st...)

# `@lock obj expr`: acquire `obj` with `lock`, evaluate `expr`, and always release
# with `unlock`, even when `expr` throws. The value of `expr` is the value of
# the whole form.
#
# The lock expression is evaluated exactly once and bound to a hygienic
# temporary, so an `obj` expression with side effects (or one that could return
# a different object on a second evaluation) is locked and unlocked consistently.
macro lock(obj, expr)
    quote
        lockobj = $(esc(obj))
        lock(lockobj)
        try
            $(esc(expr))
        finally
            unlock(lockobj)
        end
    end
end

# when put!-ing an object, operation may have to wait until other tasks have put their
# objects to ensure the channel is ordered correctly
function Base.put!(ch::OrderedChannel{T}, x::T, i::Integer, incr::Bool=false) where {T}
@lock ch.cond begin
while ch.i[] < i
# channel index too early, need to wait for other tasks to put their objects first
wait(ch.cond)
end
# now it's our turn
put!(ch.chan, x)
if incr
ch.i[] += 1
end
# wake up tasks that may be waiting to put their objects
notify(ch.cond)
end
return
end

function Base.close(ch::OrderedChannel)
@lock ch.cond begin
# just need to ensure any tasks waiting to put their tasks have had a chance to put
while Base.n_waiters(ch.cond) > 0
wait(ch.cond)
end
close(ch.chan)
end
return
end

struct Lockable{T}
x::T
lock::ReentrantLock
end

Lockable(x::T) where {T} = Lockable{T}(x, ReentrantLock())

Base.lock(x::Lockable) = lock(x.lock)
Base.unlock(x::Lockable) = unlock(x.lock)

0 comments on commit 799ef10

Please sign in to comment.