Skip to content

Commit

Permalink
partitioned reader with Tables.partitions
Browse files Browse the repository at this point in the history
Make the `Tables.partitions(Parquet.Table)` implementation do lazy loading of partitions.
  • Loading branch information
tanmaykm committed Mar 2, 2021
1 parent 3573f98 commit 4ba8213
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 42 deletions.
10 changes: 9 additions & 1 deletion src/schema.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,15 @@ end
# Return the maximum definition level for the column path `schname`.
# An optional (non-required) element contributes one level; levels accumulate
# up the schema tree via recursion on the parent path.
function max_definition_level(sch::Schema, schname::T) where {T <: AbstractVector{String}}
    level = isrequired(sch, schname) ? 0 : 1
    if istoplevel(schname)
        return level
    else
        return level + max_definition_level(sch, parentname(schname))
    end
end
end

# Build a `Tables.Schema` (column names + element types) for a parquet file.
tables_schema(parfile) = tables_schema(schema(parfile))

function tables_schema(sch::Schema)
    # NamedTuple type whose fields are the column vector types of the root element
    nt = Parquet.ntcolstype(sch, sch.schema[1])
    names = fieldnames(nt)
    # element type of each column vector type
    types = map(eltype, fieldtypes(nt))
    return Tables.Schema(names, types)
end

logical_decimal_unscaled_type(precision::Int32) = (precision < 5) ? UInt16 :
(precision < 10) ? UInt32 :
Expand Down
130 changes: 96 additions & 34 deletions src/simple_reader.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,6 @@

# NOTE(review): removed-diff residue — this is the pre-commit eager `Table`
# (superseded by the lazy-loading definition below in this commit).
struct Table <: Tables.AbstractColumns
# full table schema (names + types)
schema::Tables.Schema
# eagerly-read column batches, one entry per row-group batch
chunks::Vector
lookup::Dict{Symbol, Int} # map column name => index
# concatenated full-length columns
columns::Vector{AbstractVector}
end

# NOTE(review): removed-diff residue — Tables.jl interface of the pre-commit
# eager `Table` (it references the `:chunks` field that the new struct drops).
Tables.istable(::Table) = true
Tables.columnaccess(::Table) = true
Tables.columns(t::Table) = Tables.CopiedColumns(t)
Tables.schema(t::Table) = getfield(t, :schema)
Tables.columnnames(t::Table) = getfield(t, :schema).names
Tables.getcolumn(t::Table, nm::Symbol) = Tables.getcolumn(t, getfield(t, :lookup)[nm])
Tables.getcolumn(t::Table, i::Int) = getfield(t, :columns)[i]
# partitions were the pre-read chunks themselves (eager, not lazy)
Tables.partitions(t::Table) = getfield(t, :chunks)

"""
read_parquet(path)
read_parquet(path; kwargs...)
Parquet.Table(path; kwargs...)
Returns the table contained in the parquet file in a Tables.jl-compatible format.
Expand All @@ -33,20 +17,98 @@ using DataFrames
df = DataFrame(read_parquet(path; copycols=false))
```
"""
function read_parquet(path;
rows::Union{Nothing,UnitRange}=nothing,
batchsize::Union{Nothing,Signed}=nothing,
use_threads::Bool=(nthreads() > 1))

parquetfile = Parquet.File(path);
kwargs = Dict{Symbol,Any}(:use_threads => use_threads, :reusebuffer => false)
(rows === nothing) || (kwargs[:rows] = rows)
(batchsize === nothing) || (kwargs[:batchsize] = batchsize)

# read all the chunks
chunks = [chunk for chunk in BatchedColumnsCursor(parquetfile; kwargs...)]
sch = Tables.schema(chunks[1])
N = length(sch.names)
columns = length(chunks) == 1 ? AbstractVector[chunks[1][i] for i = 1:N] : AbstractVector[ChainedVector([chunks[j][i] for j = 1:length(chunks)]) for i = 1:N]
return Table(sch, chunks, Dict{Symbol, Int}(nm => i for (i, nm) in enumerate(sch.names)), columns)
# Lazily-loaded, Tables.jl-compatible columnar view of a parquet file.
# Column data is materialized only on first access (see `load`); until then
# `columns` stays empty. Field names are part of the interface — sibling
# methods read them via `getfield`.
struct Table <: Tables.AbstractColumns
path::String
rows::Union{Nothing,UnitRange}
batchsize::Union{Nothing,Signed}
use_threads::Bool
parfile::Parquet.File
schema::Tables.Schema
lookup::Dict{Symbol, Int} # map column name => index
columns::Vector{AbstractVector}

function Table(path;
        rows::Union{Nothing,UnitRange}=nothing,
        batchsize::Union{Nothing,Signed}=nothing,
        use_threads::Bool=(nthreads() > 1))
    file = Parquet.File(path)
    tblschema = tables_schema(file)
    # column name => 1-based column index, for Symbol-keyed getcolumn
    colindex = Dict{Symbol, Int}()
    for (idx, name) in enumerate(tblschema.names)
        colindex[name] = idx
    end
    # columns start empty; they are filled lazily by `load`
    return new(path, rows, batchsize, use_threads, file, tblschema, colindex, AbstractVector[])
end
end

# `read_parquet(path; kwargs...)` is an alias for the `Table` constructor.
const read_parquet = Table

# One row-group batch of a `Table`, exposed as a Tables.jl column table.
struct TablePartition <: Tables.AbstractColumns
# parent table (provides schema and the name => index lookup)
table::Table
# the columns of just this batch
columns::Vector{AbstractVector}
end

# Lazy iterator over the batches of a `Table`, returned by `Tables.partitions`.
# Each iteration yields a `TablePartition` read on demand from the cursor.
struct TablePartitions
table::Table
cursor::BatchedColumnsCursor

TablePartitions(table::Table) = new(table, cursor(table))
end
# Number of batches the cursor will produce.
# Qualified as `Base.length` so this extends the Base generic: an unqualified
# `length(tp::TablePartitions) = ...` would create a new module-local function,
# and `length(Tables.partitions(t))` called from outside the module would
# dispatch to `Base.length` and throw a MethodError (the sibling `Base.iterate`
# methods below are already qualified this way).
Base.length(tp::TablePartitions) = length(tp.cursor)
# Convert one raw cursor iteration result into a `(TablePartition, state)`
# pair, or `nothing` when the cursor is exhausted.
function iterated_partition(partitions::TablePartitions, iterresult)
    iterresult === nothing && return nothing
    chunk, state = iterresult
    tbl = getfield(partitions, :table)
    ncols = length(Tables.schema(tbl).names)
    cols = AbstractVector[chunk[colidx] for colidx in 1:ncols]
    return (TablePartition(partitions.table, cols), state)
end

Base.iterate(partitions::TablePartitions) = iterated_partition(partitions, iterate(partitions.cursor))
Base.iterate(partitions::TablePartitions, state) = iterated_partition(partitions, iterate(partitions.cursor, state))

# Build a fresh `BatchedColumnsCursor` over the table's parquet file, passing
# `rows`/`batchsize` through only when they were supplied at construction.
function cursor(table::Table)
    opts = Dict{Symbol,Any}(:use_threads => getfield(table, :use_threads), :reusebuffer => false)
    rowrange = getfield(table, :rows)
    rowrange === nothing || (opts[:rows] = rowrange)
    bsz = getfield(table, :batchsize)
    bsz === nothing || (opts[:batchsize] = bsz)
    return BatchedColumnsCursor(getfield(table, :parfile); opts...)
end

# True once column data has been materialized into `table.columns`.
loaded(table::Table) = !isempty(getfield(table, :columns))

# Materialize all columns using a fresh cursor over the file.
load(table::Table) = load(table, cursor(table))

# Read every batch from `chunks` and fill `table.columns` with full-length
# column vectors. A single batch is stored as-is; multiple batches are stitched
# together per column with `ChainedVector` (no copy). Returns `nothing`.
function load(table::Table, chunks::BatchedColumnsCursor)
    batches = collect(chunks)
    ncols = length(Tables.schema(table).names)
    columns = getfield(table, :columns)

    empty!(columns)
    if length(batches) == 1
        only_batch = batches[1]
        for colidx in 1:ncols
            push!(columns, only_batch[colidx])
        end
    else
        for colidx in 1:ncols
            push!(columns, ChainedVector([batch[colidx] for batch in batches]))
        end
    end
    return nothing
end

# Tables.jl column-access interface for `Table`.
Tables.istable(::Table) = true
Tables.columnaccess(::Table) = true
Tables.schema(t::Table) = getfield(t, :schema)
Tables.columnnames(t::Table) = Tables.schema(t).names
Tables.columns(t::Table) = Tables.CopiedColumns(t)

function Tables.getcolumn(t::Table, nm::Symbol)
    return Tables.getcolumn(t, getfield(t, :lookup)[nm])
end

function Tables.getcolumn(t::Table, i::Int)
    # materialize all columns lazily on first access
    loaded(t) || load(t)
    return getfield(t, :columns)[i]
end

# partitions iterate batches lazily instead of pre-reading them
Tables.partitions(t::Table) = TablePartitions(t)

# Tables.jl column-access interface for `TablePartition`; schema and name
# lookup are delegated to the parent table, data comes from this batch only.
Tables.istable(::TablePartition) = true
Tables.columnaccess(::TablePartition) = true
Tables.schema(tp::TablePartition) = Tables.schema(getfield(tp, :table))
Tables.columnnames(tp::TablePartition) = Tables.columnnames(getfield(tp, :table))
Tables.columns(tp::TablePartition) = Tables.CopiedColumns(tp)

function Tables.getcolumn(tp::TablePartition, nm::Symbol)
    lookup = getfield(getfield(tp, :table), :lookup)
    return Tables.getcolumn(tp, lookup[nm])
end

Tables.getcolumn(tp::TablePartition, i::Int) = getfield(tp, :columns)[i]
17 changes: 10 additions & 7 deletions test/test_load.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,23 +270,26 @@ function test_load_file()
# Full file: 2 row groups of 50 rows each => 100 rows, 2 partitions.
table = read_parquet(filename)
cols = Tables.columns(table)
@test all([length(col)==100 for col in cols]) # all columns must be 100 rows long
# NOTE(review): removed-diff residue — `Table` no longer has a `:chunks` field
# after the lazy-loading rewrite; the next three lines belong to the old test.
@test length(getfield(table, :chunks)) == 2
@test 50 == length(getfield(table, :chunks)[1][1])
@test 50 == length(getfield(table, :chunks)[2][1])
@test length(cols) == 12 # 12 columns
partitions = Tables.partitions(table)
@test length(partitions) == 2
@test length(collect(partitions)) == 2

# Row-range restriction: 10 rows fit in a single partition.
table = read_parquet(filename; rows=1:10)
cols = Tables.columns(table)
@test all([length(col)==10 for col in cols]) # all columns must be 10 rows long
# NOTE(review): removed-diff residue — old `:chunks`-based assertion.
@test length(getfield(table, :chunks)) == 1
@test length(cols) == 12 # 12 columns
partitions = Tables.partitions(table)
@test length(partitions) == 1
@test length(collect(partitions)) == 1

# Explicit batch size: 100 rows in batches of 10 => 10 partitions.
table = read_parquet(filename; rows=1:100, batchsize=10)
cols = Tables.columns(table)
@test all([length(col)==100 for col in cols]) # all columns must be 100 rows long
# NOTE(review): removed-diff residue — old `:chunks`-based assertions.
@test 10 == length(getfield(table, :chunks)[1][1])
@test length(getfield(table, :chunks)) == 10
@test all([length(col)==100 for col in cols]) # all columns must be 100 rows long
@test length(cols) == 12 # 12 columns
partitions = Tables.partitions(table)
@test length(partitions) == 10
@test length(collect(partitions)) == 10
end
end

Expand Down

0 comments on commit 4ba8213

Please sign in to comment.