Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dictionary encoding and decoding #234

Merged
merged 8 commits into from
Aug 14, 2023
90 changes: 40 additions & 50 deletions src/codec/decode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,75 +30,65 @@ function decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{Enum{Int32
end
decode(d::AbstractProtoDecoder, ::Type{T}) where {T <: Union{Float64,Float32}} = read(d.io, T)
function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}) where {K,V<:_ScalarTypesEnum}
len = vbyte_decode(d.io, UInt32)
endpos = position(d.io) + len
while position(d.io) < endpos
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, V)
buffer[key] = val
end
@assert position(d.io) == endpos
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, V)
@assert position(d.io) == pair_end_pos
buffer[key] = val
nothing
end

function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}) where {K,V}
len = vbyte_decode(d.io, UInt32)
endpos = position(d.io) + len
while position(d.io) < endpos
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, Ref{V})
buffer[key] = val
end
@assert position(d.io) == endpos
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, Ref{V})
@assert position(d.io) == pair_end_pos
buffer[key] = val
nothing
end

for T in (:(:fixed), :(:zigzag))
@eval function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}, ::Type{Val{Tuple{Nothing,$(T)}}}) where {K,V}
len = vbyte_decode(d.io, UInt32)
endpos = position(d.io) + len
while position(d.io) < endpos
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, V, Val{$(T)})
buffer[key] = val
end
@assert position(d.io) == endpos
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K)
field_number, wire_type = decode_tag(d)
val = decode(d, V, Val{$(T)})
@assert position(d.io) == pair_end_pos
buffer[key] = val
nothing
end

@eval function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}, ::Type{Val{Tuple{$(T),Nothing}}}) where {K,V}
len = vbyte_decode(d.io, UInt32)
endpos = position(d.io) + len
while position(d.io) < endpos
field_number, wire_type = decode_tag(d)
key = decode(d, K, Val{$(T)})
field_number, wire_type = decode_tag(d)
val = decode(d, V)
buffer[key] = val
end
@assert position(d.io) == endpos
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K, Val{$(T)})
field_number, wire_type = decode_tag(d)
val = decode(d, V)
@assert position(d.io) == pair_end_pos
buffer[key] = val
nothing
end
end

for T in (:(:fixed), :(:zigzag)), S in (:(:fixed), :(:zigzag))
@eval function decode!(d::AbstractProtoDecoder, buffer::Dict{K,V}, ::Type{Val{Tuple{$(T),$(S)}}}) where {K,V}
len = vbyte_decode(d.io, UInt32)
endpos = position(d.io) + len
while position(d.io) < endpos
field_number, wire_type = decode_tag(d)
key = decode(d, K, Val{$(T)})
field_number, wire_type = decode_tag(d)
val = decode(d, V, Val{$(S)})
buffer[key] = val
end
@assert position(d.io) == endpos
pair_len = vbyte_decode(d.io, UInt32)
pair_end_pos = position(d.io) + pair_len
field_number, wire_type = decode_tag(d)
key = decode(d, K, Val{$(T)})
field_number, wire_type = decode_tag(d)
val = decode(d, V, Val{$(S)})
@assert position(d.io) == pair_end_pos
buffer[key] = val
nothing
end
end
Expand Down
56 changes: 28 additions & 28 deletions src/codec/encode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,40 +139,52 @@ function _encode(io::IO, x::Vector{T}) where {T<:Union{UInt32,UInt64,Int32,Int64
return nothing
end

function _encode(_e::ProtoEncoder, x::Dict{K,V}) where {K,V}
maybe_ensure_room(_e.io, 2length(x))
function encode(e::ProtoEncoder, i::Int, x::Dict{K,V}) where {K,V}
maybe_ensure_room(e.io, 2*(length(x)+1))
for (k, v) in x
encode(_e, 1, k)
encode(_e, 2, v)
# encode header for key-value pair message
encode_tag(e, i, LENGTH_DELIMITED)
vbyte_encode(e.io, UInt32(_encoded_size(k, 1) + _encoded_size(v, 2)))
encode(e, 1, k)
encode(e, 2, v)
end
nothing
end

for T in (:(:fixed), :(:zigzag))
@eval function _encode(_e::ProtoEncoder, x::Dict{K,V}, ::Type{Val{Tuple{$(T),Nothing}}}) where {K,V}
maybe_ensure_room(_e.io, 2length(x))
@eval function encode(e::ProtoEncoder, i::Int, x::Dict{K,V}, ::Type{Val{Tuple{$(T),Nothing}}}) where {K,V}
maybe_ensure_room(e.io, 2*(length(x)+1))
for (k, v) in x
encode(_e, 1, k, Val{$(T)})
encode(_e, 2, v)
# encode header for key-value pair message
encode_tag(e, i, LENGTH_DELIMITED)
vbyte_encode(e.io, UInt32(_encoded_size(k, 1, Val{$(T)}) + _encoded_size(v, 2)))
encode(e, 1, k, Val{$(T)})
encode(e, 2, v)
end
nothing
end
@eval function _encode(_e::ProtoEncoder, x::Dict{K,V}, ::Type{Val{Tuple{Nothing,$(T)}}}) where {K,V}
maybe_ensure_room(_e.io, 2length(x))
@eval function encode(e::ProtoEncoder, i::Int, x::Dict{K,V}, ::Type{Val{Tuple{Nothing,$(T)}}}) where {K,V}
maybe_ensure_room(e.io, 2*(length(x)+1))
for (k, v) in x
encode(_e, 1, k)
encode(_e, 2, v, Val{$(T)})
# encode header for key-value pair message
encode_tag(e, i, LENGTH_DELIMITED)
vbyte_encode(e.io, UInt32(_encoded_size(k, 1) + _encoded_size(v, 2, Val{$(T)})))
encode(e, 1, k)
encode(e, 2, v, Val{$(T)})
end
nothing
end
end

for T in (:(:fixed), :(:zigzag)), S in (:(:fixed), :(:zigzag))
@eval function _encode(_e::AbstractProtoEncoder, x::Dict{K,V}, ::Type{Val{Tuple{$(T),$(S)}}}) where {K,V}
maybe_ensure_room(_e.io, 2length(x))
@eval function encode(e::AbstractProtoEncoder, i::Int, x::Dict{K,V}, ::Type{Val{Tuple{$(T),$(S)}}}) where {K,V}
maybe_ensure_room(e.io, 2*(length(x)+1))
for (k, v) in x
encode(_e, 1, k, Val{$(T)})
encode(_e, 2, v, Val{$(S)})
# encode header for key-value pair message
encode_tag(e, i, LENGTH_DELIMITED)
vbyte_encode(e.io, UInt32(_encoded_size(k, 1, Val{$(T)}) + _encoded_size(v, 2, Val{$(S)})))
encode(e, 1, k, Val{$(T)})
encode(e, 2, v, Val{$(S)})
end
nothing
end
Expand Down Expand Up @@ -269,18 +281,6 @@ function encode(e::AbstractProtoEncoder, i::Int, x::Vector{T}) where {T<:Union{U
return nothing
end

function encode(e::AbstractProtoEncoder, i::Int, x::Dict{K,V}) where {K,V}
encode_tag(e, i, LENGTH_DELIMITED)
_with_size(_encode, e.io, e, x)
return nothing
end

function encode(e::AbstractProtoEncoder, i::Int, x::Dict{K,V}, ::Type{W}) where {K,V,W}
encode_tag(e, i, LENGTH_DELIMITED)
_with_size(_encode, e.io, e, x, W)
return nothing
end

function encode(e::AbstractProtoEncoder, i::Int, x::Vector{T}, ::Type{Val{:zigzag}}) where {T<:Union{Int32,Int64}}
encode_tag(e, i, LENGTH_DELIMITED)
_with_size(_encode, e.io, e.io, x, Val{:zigzag})
Expand Down
38 changes: 30 additions & 8 deletions src/codec/encoded_size.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

# For Length-Delimited fields we don't include the encoded number of bytes
# unless we also provide the field number in which case we encode both the
# tag and lenght
# tag and length
_with_size(n::Int) = (n + _encoded_size(n))
_encoded_size(x::String) = sizeof(x)

Expand All @@ -29,19 +29,41 @@
_encoded_size(xs::AbstractVector{T}) where {T<:Union{String,AbstractVector{UInt8}}} = sum(x->_with_size(_encoded_size(x)), xs, init=0)
_encoded_size(xs::AbstractVector{T}, ::Type{Val{:fixed}}) where {T<:Union{Int32,UInt32,Int64,UInt64}} = sizeof(xs)

# Dicts add dummy tags to both keys and values
_encoded_size(d::AbstractDict) = mapreduce(x->_encoded_size(x.first, 1) + _encoded_size(x.second, 2), +, d, init=0)
# Dicts add dummy tags to both keys and values, and a tag to each key-value pair.
# _encoded_size(::AbstractDict) does not include the per-pair tag (which carries the field number);
# that tag is added in the _encoded_size(::AbstractDict, ::Int) methods below, because the
# field number is not known at this point.
function _encoded_size(d::AbstractDict)
mapreduce(x->begin
total_size = _encoded_size(x.first, 1) + _encoded_size(x.second, 2)
return _varint_size(total_size) + total_size
end, +, d, init=0)
end
_encoded_size(xs::AbstractDict, i::Int) = _encoded_size(i << 3) * length(xs) + _encoded_size(xs)

for T in (:(:fixed), :(:zigzag))
@eval _encoded_size(d::AbstractDict, ::Type{Val{Tuple{$(T),Nothing}}}) = mapreduce(x->_encoded_size(x.first, 1, Val{$(T)}) + _encoded_size(x.second, 2), +, d, init=0)
@eval _encoded_size(d::AbstractDict, ::Type{Val{Tuple{Nothing,$(T)}}}) = mapreduce(x->_encoded_size(x.first, 1) + _encoded_size(x.second, 2, Val{$(T)}), +, d, init=0)
@eval _encoded_size(d::AbstractDict, ::Type{Val{Tuple{$(T),Nothing}}}) = mapreduce(x->begin
total_size = _encoded_size(x.first, 1, Val{$(T)}) + _encoded_size(x.second, 2)
return _varint_size(total_size) + total_size
end, +, d, init=0)
@eval _encoded_size(xs::AbstractDict, i::Int, ::Type{Val{Tuple{$(T),Nothing}}}) = _encoded_size(i << 3) * length(xs) + _encoded_size(xs, Val{Tuple{$(T),Nothing}})

@eval _encoded_size(d::AbstractDict, ::Type{Val{Tuple{Nothing,$(T)}}}) = mapreduce(x->begin
total_size = _encoded_size(x.first, 1) + _encoded_size(x.second, 2, Val{$(T)})
return _varint_size(total_size) + total_size
end, +, d, init=0)
@eval _encoded_size(xs::AbstractDict, i::Int, ::Type{Val{Tuple{Nothing,$(T)}}}) = _encoded_size(i << 3) * length(xs) + _encoded_size(xs, Val{Tuple{Nothing,$(T)}})

@eval _encoded_size(xs::Union{AbstractDict,AbstractVector}, i::Int, ::Type{Val{$(T)}}) = _encoded_size(i << 3) + _with_size(_encoded_size(xs, Val{$(T)}))
@eval _encoded_size(xs::Union{Int32,Int64,UInt64,UInt32}, i::Int, ::Type{Val{$(T)}}) = _encoded_size(i << 3) + _encoded_size(xs, Val{$(T)})
@eval _encoded_size(xs::AbstractVector, i::Int, ::Type{Val{$(T)}}) = _encoded_size(i << 3) + _with_size(_encoded_size(xs, Val{$(T)}))

Check warning on line 57 in src/codec/encoded_size.jl

View check run for this annotation

Codecov / codecov/patch

src/codec/encoded_size.jl#L57

Added line #L57 was not covered by tests
@eval _encoded_size(xs::Union{Int32,Int64,UInt64,UInt32}, i::Int, ::Type{Val{$(T)}}) = _encoded_size(i << 3) + _encoded_size(xs, Val{$(T)})
end

for T in (:(:fixed), :(:zigzag)), S in (:(:fixed), :(:zigzag))
@eval _encoded_size(d::AbstractDict, ::Type{Val{Tuple{$(T),$(S)}}}) = mapreduce(x->_encoded_size(x.first, 1, Val{$(S)}) + _encoded_size(x.second, 2, Val{$(S)}), +, d, init=0)
@eval _encoded_size(d::AbstractDict, ::Type{Val{Tuple{$(T),$(S)}}}) = mapreduce(x->begin
total_size = _encoded_size(x.first, 1, Val{$(T)}) + _encoded_size(x.second, 2, Val{$(S)})
return _varint_size(total_size) + total_size
end, +, d, init=0)
@eval _encoded_size(xs::AbstractDict, i::Int, ::Type{Val{Tuple{$(T),$(S)}}}) = _encoded_size(i << 3) * length(xs) + _encoded_size(xs, Val{Tuple{$(T),$(S)}})
end

# These methods handle fields that refer to messages/groups
Expand Down
Loading
Loading