Skip to content
Permalink
Browse files
Allow initial support for custom extension types (#11)
* Allow initial support for custom extension types

Fixes #10. This PR is large mainly because it adds fairly generic
support for custom field metadata. We then use that support to add
custom type extension for Symbol/Char vectors. We could perhaps factor
this a bit more to make it even more generic for custom types, but I
think this gets us pretty darn far.

* Fix dict encoded
  • Loading branch information
quinnj committed Oct 3, 2020
1 parent 24be05c commit 7f8c4e0385954a84adbf8450a95f40c03e2f858f
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 42 deletions.
@@ -26,7 +26,12 @@ finaljuliatype(::Type{Union{T, Missing}}) where {T} = Union{Missing, finaljuliat

function juliaeltype(f::Meta.Field)
T = juliaeltype(f, f.type)
return f.nullable ? Union{T, Missing} : T
if f.custom_metadata !== nothing
fm = Dict(kv.key => kv.value for kv in f.custom_metadata)
else
fm = nothing
end
return (f.nullable ? Union{T, Missing} : T), fm
end

juliaeltype(f::Meta.Field, ::Meta.Null) = Missing
@@ -179,11 +184,9 @@ function arrowtype(b, ::Type{Date{U, T}}) where {U, T}
return Meta.Date, Meta.dateEnd(b), nothing
end

arrowtype(b, ::Type{Dates.Date}) = arrowtype(b, Date{Meta.DateUnit.DAY, Int32})
const UNIX_EPOCH_DATE = Dates.value(Dates.Date(1970))
Base.convert(::Type{Date{Meta.DateUnit.DAY, Int32}}, x::Dates.Date) = Date{Meta.DateUnit.DAY, Int32}(Int32(Dates.value(x) - UNIX_EPOCH_DATE))

arrowtype(b, ::Type{Dates.DateTime}) = arrowtype(b, Date{Meta.DateUnit.MILLISECOND, Int64})
const UNIX_EPOCH_DATETIME = Dates.value(Dates.DateTime(1970))
Base.convert(::Type{Date{Meta.DateUnit.MILLISECOND, Int64}}, x::Dates.DateTime) = Date{Meta.DateUnit.MILLISECOND, Int64}(Int64(Dates.value(x) - UNIX_EPOCH_DATETIME))

@@ -209,7 +212,6 @@ function arrowtype(b, ::Type{Time{U, T}}) where {U, T}
return Meta.Time, Meta.timeEnd(b), nothing
end

arrowtype(b, ::Type{Dates.Time}) = arrowtype(b, Time{Meta.TimeUnit.NANOSECOND, Int64})
Base.convert(::Type{Time{Meta.TimeUnit.NANOSECOND, Int64}}, x::Dates.Time) = Time{Meta.TimeUnit.NANOSECOND, Int64}(Dates.value(x))

struct Timestamp{U, TZ} <: ArrowTimeType
@@ -278,41 +280,41 @@ arrowperiodtype(::Type{Dates.Millisecond}) = Meta.TimeUnit.MILLISECOND
arrowperiodtype(::Type{Dates.Microsecond}) = Meta.TimeUnit.MICROSECOND
arrowperiodtype(::Type{Dates.Nanosecond}) = Meta.TimeUnit.NANOSECOND

arrowtype(b, ::Type{P}) where {P <: Dates.Period} = arrowtype(b, Duration{arrowperiodtype(P)})
Base.convert(::Type{Duration{U}}, x::Dates.Period) where {U} = Duration{U}(Dates.value(periodtype(U)(x)))

# nested types; call juliaeltype recursively on nested children
function juliaeltype(f::Meta.Field, list::Union{Meta.List, Meta.LargeList})
return Vector{juliaeltype(f.children[1])}
T, _ = juliaeltype(f.children[1])
return Vector{T}
end

# arrowtype will call fieldoffset recursively for children
function arrowtype(b, ::Type{Vector{T}}) where {T}
children = [fieldoffset(b, -1, "", T, nothing)]
children = [fieldoffset(b, -1, "", T, nothing, nothing)]
Meta.listStart(b)
return Meta.List, Meta.listEnd(b), children
end

function juliaeltype(f::Meta.Field, list::Meta.FixedSizeList)
type = juliaeltype(f.children[1])
type, _ = juliaeltype(f.children[1])
return NTuple{Int(list.listSize), type}
end

function arrowtype(b, ::Type{NTuple{N, T}}) where {N, T}
children = [fieldoffset(b, -1, "", T, nothing)]
children = [fieldoffset(b, -1, "", T, nothing, nothing)]
Meta.fixedSizeListStart(b)
Meta.fixedSizeListAddListSize(b, Int32(N))
return Meta.FixedSizeList, Meta.fixedSizeListEnd(b), children
end

function juliaeltype(f::Meta.Field, map::Meta.Map)
K = juliaeltype(f.children[1].children[1])
V = juliaeltype(f.children[1].children[2])
K, _ = juliaeltype(f.children[1].children[1])
V, _ = juliaeltype(f.children[1].children[2])
return Pair{K, V}
end

function arrowtype(b, ::Type{Pair{K, V}}) where {K, V}
children = [fieldoffset(b, -1, "entries", KeyValue{K, V}, nothing)]
children = [fieldoffset(b, -1, "entries", KeyValue{K, V}, nothing, nothing)]
Meta.mapStart(b)
return Meta.Map, Meta.mapEnd(b), children
end
@@ -328,19 +330,19 @@ Base.iterate(kv::KeyValue, st=1) = st === nothing ? nothing : (kv, nothing)
default(::Type{KeyValue{K, V}}) where {K, V} = KeyValue(default(K), default(V))

function arrowtype(b, ::Type{KeyValue{K, V}}) where {K, V}
children = [fieldoffset(b, -1, "key", K, nothing), fieldoffset(b, -1, "value", V, nothing)]
children = [fieldoffset(b, -1, "key", K, nothing, nothing), fieldoffset(b, -1, "value", V, nothing, nothing)]
Meta.structStart(b)
return Meta.Struct, Meta.structEnd(b), children
end

function juliaeltype(f::Meta.Field, list::Meta.Struct)
names = Tuple(Symbol(x.name) for x in f.children)
types = Tuple(juliaeltype(x) for x in f.children)
types = Tuple(juliaeltype(x)[1] for x in f.children)
return NamedTuple{names, Tuple{types...}}
end

function arrowtype(b, ::Type{NamedTuple{names, types}}) where {names, types}
children = [fieldoffset(b, -1, names[i], fieldtype(types, i), nothing) for i = 1:length(names)]
children = [fieldoffset(b, -1, names[i], fieldtype(types, i), nothing, nothing) for i = 1:length(names)]
Meta.structStart(b)
return Meta.Struct, Meta.structEnd(b), children
end
@@ -349,7 +351,7 @@ default(::Type{NamedTuple{names, types}}) where {names, types} = NamedTuple{name

# Unions
function juliaeltype(f::Meta.Field, u::Meta.Union)
return UnionT{u.mode, u.typeIds !== nothing ? Tuple(u.typeIds) : u.typeIds, Tuple{(juliaeltype(x) for x in f.children)...}}
return UnionT{u.mode, u.typeIds !== nothing ? Tuple(u.typeIds) : u.typeIds, Tuple{(juliaeltype(x)[1] for x in f.children)...}}
end

# Note: nested Union types can't be represented using julia's builtin Union{...}
@@ -363,7 +365,7 @@ function arrowtype(b, ::Type{UnionT{T, typeIds, U}}) where {T, typeIds, U}
end
TI = FlatBuffers.endvector!(b, length(typeIds))
end
children = [fieldoffset(b, -1, "", fieldtype(U, i), nothing) for i = 1:fieldcount(U)]
children = [fieldoffset(b, -1, "", fieldtype(U, i), nothing, nothing) for i = 1:fieldcount(U)]
Meta.unionStart(b)
Meta.unionAddMode(b, T)
if typeIds !== nothing
@@ -53,16 +53,21 @@ function Table(bytes::Vector{UInt8}, off::Integer=1, tlen::Union{Integer, Nothin
sch = nothing
dictencodings = Dict{Int64, DictEncoding}()
dictencoded = Dict{Int64, Tuple{Bool, Type, Meta.Field}}()
fieldmetadata = Dict{Int, Dict{String, String}}()
for batch in BatchIterator{debug}(bytes, off)
# store custom_metadata of batch.msg?
header = batch.msg.header
if header isa Meta.Schema
debug && println("parsing schema message")
# assert endianness?
# store custom_metadata?
for field in header.fields
for (i, field) in enumerate(header.fields)
push!(names(t), Symbol(field.name))
push!(types(t), juliaeltype(field))
T, metadata = juliaeltype(field)
if metadata !== nothing
fieldmetadata[i] = metadata
end
push!(types(t), T)
d = field.dictionary
isencoded = false
if d !== nothing
@@ -117,6 +122,19 @@ function Table(bytes::Vector{UInt8}, off::Integer=1, tlen::Union{Integer, Nothin
end
lu = lookup(t)
for (i, (k, T, col)) in enumerate(zip(names(t), types(t), columns(t)))
if haskey(fieldmetadata, i) && haskey(fieldmetadata[i], "ARROW:extension:name")
if fieldmetadata[i]["ARROW:extension:name"] == "JuliaLang.Symbol"
TT = finaljuliatype(Symbol)
types(t)[i] = TT
col = converter(TT, col)
columns(t)[i] = col
elseif fieldmetadata[i]["ARROW:extension:name"] == "JuliaLang.Char"
TT = finaljuliatype(Char)
types(t)[i] = TT
col = converter(TT, col)
columns(t)[i] = col
end
end
if convert
TT = finaljuliatype(T)
if TT !== T
@@ -236,11 +236,14 @@ struct Converter{T, A} <: AbstractVector{T}
end

converter(::Type{T}, x::A) where {T, A} = Converter{eltype(A) >: Missing ? Union{T, Missing} : T, A}(x)
converter(::Type{T}, x::ChainedVector{A}) where {T, A} = ChainedVector(Vector{A}[converter(T, x) for x in x.arrays])

Base.IndexStyle(::Type{<:Converter}) = Base.IndexLinear()
Base.size(x::Converter) = (length(x.data),)
Base.eltype(x::Converter{T, A}) where {T, A} = T
Base.getindex(x::Converter{T}, i::Int) where {T} = convert(T, getindex(x.data, i))
Base.getindex(x::Converter{Symbol, A}, i::Int) where {T, A <: AbstractVector{String}} = Symbol(getindex(x.data, i))
Base.getindex(x::Converter{Char, A}, i::Int) where {T, A <: AbstractVector{String}} = getindex(x.data, i)[1]

maybemissing(::Type{T}) where {T} = T === Missing ? Missing : Base.nonmissingtype(T)

@@ -48,16 +48,16 @@ function write(io, source, writetofile, debug)
# start message writing from channel
@static if VERSION >= v"1.3-DEV"
tsk = Threads.@spawn for msg in msgs
Base.write(io, msg, blocks)
Base.write(io, msg, blocks, sch)
end
else
tsk = @async for msg in msgs
Base.write(io, msg, blocks)
Base.write(io, msg, blocks, sch)
end
end
@sync for (i, tbl) in enumerate(parts(source))
if i == 1
cols = Tables.columns(tbl)
cols = Tables.columns(toarrowtable(tbl))
sch[] = Tables.schema(cols)
firstcols[] = cols
for (i, col) in enumerate(Tables.Columns(cols))
@@ -68,6 +68,7 @@ end
dictencodings[i] = (id, encodingtype(length(values)), values)
end
end
@show sch[]
put!(msgs, makeschemamsg(sch[], cols, dictencodings))
if !isempty(dictencodings)
for (colidx, (id, T, values)) in dictencodings
@@ -126,7 +127,7 @@ end
wait(tsk)
# write empty message
if !writetofile
Base.write(io, Message(UInt8[], nothing, nothing, 0, true, false), blocks)
Base.write(io, Message(UInt8[], nothing, nothing, 0, true, false), blocks, sch)
end
if writetofile
b = FlatBuffers.Builder(1024)
@@ -166,6 +167,54 @@ end
return io
end

struct ToArrowTable
sch::Tables.Schema
cols::Vector{Any}
fieldmetadata::Dict{Int, Dict{String, String}}
end

function toarrowtable(x)
cols = Tables.columns(x)
sch = Tables.schema(cols)
types = sch.types
N = length(types)
newcols = Vector{Any}(undef, N)
newtypes = Vector{Type}(undef, N)
fieldmetadata = Dict{Int, Dict{String, String}}()
Tables.eachcolumn(sch, cols) do col, i, nm
T, newcol = toarrow(types[i], i, col, fieldmetadata)
@inbounds newtypes[i] = T
@inbounds newcols[i] = newcol
end
return ToArrowTable(Tables.Schema(sch.names, newtypes), newcols, fieldmetadata)
end

toarrow(::Type{T}, i, col, fm) where {T} = T, col
toarrow(::Type{Dates.Date}, i, col, fm) = Date{Meta.DateUnit.DAY, Int32}, converter(Date{Meta.DateUnit.DAY, Int32}, col)
toarrow(::Type{Dates.Time}, i, col, fm) = Time{Meta.TimeUnit.NANOSECOND, Int64}, converter(Time{Meta.TimeUnit.NANOSECOND, Int64}, col)
toarrow(::Type{Dates.DateTime}, i, col, fm) = Date{Meta.DateUnit.MILLISECOND, Int64}, converter(Date{Meta.DateUnit.MILLISECOND, Int64}, col)
toarrow(::Type{P}, i, col, fm) where {P <: Dates.Period} = Duration{arrowperiodtype(P)}, converter(Duration{arrowperiodtype(P)}, col)

function toarrow(::Type{Symbol}, i, col, fm)
meta = get!(() -> Dict{String, String}(), fm, i)
meta["ARROW:extension:name"] = "JuliaLang.Symbol"
meta["ARROW:extension:metadata"] = ""
return String, (String(x) for x in col)
end

function toarrow(::Type{Char}, i, col, fm)
meta = get!(() -> Dict{String, String}(), fm, i)
meta["ARROW:extension:name"] = "JuliaLang.Char"
meta["ARROW:extension:metadata"] = ""
return String, (string(x) for x in col)
end

Tables.columns(x::ToArrowTable) = x
Tables.rowcount(x::ToArrowTable) = length(x.cols) == 0 ? 0 : length(x.cols[1])
Tables.schema(x::ToArrowTable) = x.sch
Tables.columnnames(x::ToArrowTable) = x.sch.names
Tables.getcolumn(x::ToArrowTable, i::Int) = x.cols[i]

struct Message
msgflatbuf
columns
@@ -194,7 +243,7 @@ Base.size(x::DictEncoder) = (length(x.values),)
Base.eltype(x::DictEncoder{T, A}) where {T, A} = T
Base.getindex(x::DictEncoder, i::Int) = x.pool[x.values[i]]

function Base.write(io::IO, msg::Message, blocks)
function Base.write(io::IO, msg::Message, blocks, sch)
metalen = padding(length(msg.msgflatbuf))
if msg.blockmsg
push!(blocks[msg.isrecordbatch ? 1 : 2], Block(position(io), metalen + 8, msg.bodylen))
@@ -209,14 +258,16 @@ function Base.write(io::IO, msg::Message, blocks)
n += writezeros(io, paddinglength(n))
# message body
if msg.columns !== nothing
types = sch[].types
# write out buffers
for i = 1:length(Tables.columnnames(msg.columns))
col = Tables.getcolumn(msg.columns, i)
T = types[i]
if msg.dictencodings !== nothing && haskey(msg.dictencodings, i)
_, T, vals = msg.dictencodings[i]
col = DictEncoder(col, vals, T)
end
writebuffer(io, eltype(col) === Missing ? Missing : Base.nonmissingtype(eltype(col)), col)
writebuffer(io, T === Missing ? Missing : Base.nonmissingtype(T), col)
end
end
return n
@@ -239,7 +290,7 @@ end
function makeschema(b, sch::Tables.Schema{names, types}, columns, dictencodings) where {names, types}
# build Field objects
N = length(names)
fieldoffsets = [fieldoffset(b, i, names[i], fieldtype(types, i), dictencodings) for i = 1:N]
fieldoffsets = [fieldoffset(b, i, names[i], fieldtype(types, i), dictencodings, columns.fieldmetadata) for i = 1:N]
Meta.schemaStartFieldsVector(b, N)
for off in Iterators.reverse(fieldoffsets)
FlatBuffers.prependoffset!(b, off)
@@ -259,9 +310,29 @@ function makeschemamsg(sch::Tables.Schema{names, types}, columns, dictencodings)
return makemessage(b, Meta.Schema, schema)
end

function fieldoffset(b, colidx, name, T, dictencodings)
function fieldoffset(b, colidx, name, T, dictencodings, metadata)
nameoff = FlatBuffers.createstring!(b, String(name))
nullable = T >: Missing
# check for custom metadata
if metadata !== nothing && haskey(metadata, colidx)
kvs = metadata[colidx]
kvoffs = Vector{FlatBuffers.UOffsetT}(undef, length(kvs))
for (i, (k, v)) in enumerate(kvs)
koff = FlatBuffers.createstring!(b, String(k))
voff = FlatBuffers.createstring!(b, String(v))
Meta.keyValueStart(b)
Meta.keyValueAddKey(b, koff)
Meta.keyValueAddValue(b, voff)
kvoffs[i] = Meta.keyValueEnd(b)
end
Meta.fieldStartCustomMetadataVector(b, length(kvs))
for off in Iterators.reverse(kvoffs)
FlatBuffers.prependoffset!(b, off)
end
meta = FlatBuffers.endvector!(b, length(kvs))
else
meta = FlatBuffers.UOffsetT(0)
end
# build dictionary
if dictencodings !== nothing && haskey(dictencodings, colidx)
id, encodingtype, _ = dictencodings[colidx]
@@ -294,7 +365,7 @@ function fieldoffset(b, colidx, name, T, dictencodings)
Meta.fieldAddType(b, typeoff)
Meta.fieldAddDictionary(b, dict)
Meta.fieldAddChildren(b, children)
# Meta.fieldAddCustomMetadata(b, meta)
Meta.fieldAddCustomMetadata(b, meta)
return Meta.fieldEnd(b)
end

@@ -410,15 +481,6 @@ function makenodesbuffers!(::Type{T}, col, fieldnodes, fieldbuffers, bufferoffse
return bufferoffset + padding(blen)
end

makenodesbuffers!(::Type{Dates.Date}, col, fieldnodes, fieldbuffers, bufferoffset) =
makenodesbuffers!(Date{Meta.DateUnit.DAY, Int32}, converter(Date{Meta.DateUnit.DAY, Int32}, col), fieldnodes, fieldbuffers, bufferoffset)
makenodesbuffers!(::Type{Dates.Time}, col, fieldnodes, fieldbuffers, bufferoffset) =
makenodesbuffers!(Time{Meta.TimeUnit.NANOSECOND, Int64}, converter(Time{Meta.TimeUnit.NANOSECOND, Int64}, col), fieldnodes, fieldbuffers, bufferoffset)
makenodesbuffers!(::Type{Dates.DateTime}, col, fieldnodes, fieldbuffers, bufferoffset) =
makenodesbuffers!(Date{Meta.DateUnit.MILLISECOND, Int64}, converter(Date{Meta.DateUnit.MILLISECOND, Int64}, col), fieldnodes, fieldbuffers, bufferoffset)
makenodesbuffers!(::Type{P}, col, fieldnodes, fieldbuffers, bufferoffset) where {P <: Dates.Period} =
makenodesbuffers!(Duration{arrowperiodtype(P)}, converter(Duration{arrowperiodtype(P)}, col), fieldnodes, fieldbuffers, bufferoffset)

function writebitmap(io, col)
nullcount(col) == 0 && return 0
len = _length(col)
@@ -447,11 +509,6 @@ function writebuffer(io, ::Type{T}, col) where {T}
return
end

writebuffer(io, ::Type{Dates.Date}, col) = writebuffer(io, Date{Meta.DateUnit.DAY, Int32}, converter(Date{Meta.DateUnit.DAY, Int32}, col))
writebuffer(io, ::Type{Dates.Time}, col) = writebuffer(io, Time{Meta.TimeUnit.NANOSECOND, Int64}, converter(Time{Meta.TimeUnit.NANOSECOND, Int64}, col))
writebuffer(io, ::Type{Dates.DateTime}, col) = writebuffer(io, Date{Meta.DateUnit.MILLISECOND, Int64}, converter(Date{Meta.DateUnit.MILLISECOND, Int64}, col))
writebuffer(io, ::Type{P}, col) where {P <: Dates.Period} = writebuffer(io, Duration{arrowperiodtype(P)}, converter(Duration{arrowperiodtype(P)}, col))

function makenodesbuffers!(::Type{T}, col, fieldnodes, fieldbuffers, bufferoffset) where {T <: Union{AbstractString, AbstractVector}}
len = _length(col)
nc = nullcount(col)
@@ -171,4 +171,17 @@ tt = Arrow.Table(io)
@test length(tt) == length(t)
@test all(isequal.(values(t), values(tt)))

# non-standard types
t = (
col1=[:hey, :there, :sailor],
col2=['a', 'b', 'c'],
)
io = IOBuffer()
Arrow.write(io, t)
seekstart(io)
tt = Arrow.Table(io)
@test length(tt) == length(t)
@test all(isequal.(values(t), values(tt)))


end

0 comments on commit 7f8c4e0

Please sign in to comment.