diff --git a/src/common.jl b/src/common.jl index 3ff3707..1cf8df1 100644 --- a/src/common.jl +++ b/src/common.jl @@ -17,6 +17,7 @@ julia> update!(ctx, b"data to to be hashed") ``` """ function update!(context::T, data::U, datalen=length(data)) where {T<:SHA_CTX, U<:AbstractBytes} + context.used && error("Cannot update CTX after `digest!` has been called on it") # We need to do all our arithmetic in the proper bitwidth UIntXXX = typeof(context.bytecount) @@ -82,6 +83,7 @@ end digest!(context) Finalize the SHA context and return the hash as array of bytes (Array{Uint8, 1}). +Updating the context after calling `digest!` on it will error. # Examples ```julia-repl @@ -97,18 +99,26 @@ julia> digest!(ctx) ⋮ 0x89 0xf5 + +julia> update!(ctx, b"more data") +ERROR: Cannot update CTX after `digest!` has been called on it +[...] ``` """ function digest!(context::T) where T<:SHA_CTX - pad_remainder!(context) - # Store the length of the input data (in bits) at the end of the padding - bitcount_idx = div(short_blocklen(T), sizeof(context.bytecount)) + 1 - pbuf = Ptr{typeof(context.bytecount)}(pointer(context.buffer)) - unsafe_store!(pbuf, bswap(context.bytecount * 8), bitcount_idx) - - # Final transform: - transform!(context) + if !context.used + pad_remainder!(context) + # Store the length of the input data (in bits) at the end of the padding + bitcount_idx = div(short_blocklen(T), sizeof(context.bytecount)) + 1 + pbuf = Ptr{typeof(context.bytecount)}(pointer(context.buffer)) + unsafe_store!(pbuf, bswap(context.bytecount * 8), bitcount_idx) + + # Final transform: + transform!(context) + bswap!(context.state) + context.used = true + end # Return the digest - return reinterpret(UInt8, bswap!(context.state))[1:digestlen(T)] + return reinterpret(UInt8, context.state)[1:digestlen(T)] end diff --git a/src/sha3.jl b/src/sha3.jl index 894cd3d..58393cc 100644 --- a/src/sha3.jl +++ b/src/sha3.jl @@ -55,26 +55,29 @@ end # Finalize data in the buffer, append total bitlength, and return our precious hash! function digest!(context::T) where {T<:SHA3_CTX} - usedspace = context.bytecount % blocklen(T) - # If we have anything in the buffer still, pad and transform that data - if usedspace < blocklen(T) - 1 - # Begin padding with a 0x06 - context.buffer[usedspace+1] = 0x06 - # Fill with zeros up until the last byte - context.buffer[usedspace+2:end-1] .= 0x00 - # Finish it off with a 0x80 - context.buffer[end] = 0x80 - else - # Otherwise, we have to add on a whole new buffer just for the zeros and 0x80 - context.buffer[end] = 0x06 - transform!(context) + if !context.used + usedspace = context.bytecount % blocklen(T) + # If we have anything in the buffer still, pad and transform that data + if usedspace < blocklen(T) - 1 + # Begin padding with a 0x06 + context.buffer[usedspace+1] = 0x06 + # Fill with zeros up until the last byte + context.buffer[usedspace+2:end-1] .= 0x00 + # Finish it off with a 0x80 + context.buffer[end] = 0x80 + else + # Otherwise, we have to add on a whole new buffer just for the zeros and 0x80 + context.buffer[end] = 0x06 + transform!(context) - context.buffer[1:end-1] .= 0x0 - context.buffer[end] = 0x80 - end + context.buffer[1:end-1] .= 0x0 + context.buffer[end] = 0x80 + end - # Final transform: - transform!(context) + # Final transform: + transform!(context) + context.used = true + end # Return the digest return reinterpret(UInt8, context.state)[1:digestlen(T)] diff --git a/src/types.jl b/src/types.jl index bfe2a43..7d25691 100644 --- a/src/types.jl +++ b/src/types.jl @@ -12,6 +12,7 @@ mutable struct SHA1_CTX <: SHA_CTX bytecount::UInt64 buffer::Array{UInt8,1} W::Array{UInt32,1} + used::Bool end # SHA2 224/256/384/512-bit Context Structures @@ -19,24 +20,28 @@ mutable struct SHA2_224_CTX <: SHA2_CTX state::Array{UInt32,1} bytecount::UInt64 buffer::Array{UInt8,1} + used::Bool end mutable struct SHA2_256_CTX <: SHA2_CTX state::Array{UInt32,1} bytecount::UInt64 buffer::Array{UInt8,1} + used::Bool end mutable struct SHA2_384_CTX <: SHA2_CTX state::Array{UInt64,1} bytecount::UInt128 buffer::Array{UInt8,1} + used::Bool end mutable struct SHA2_512_CTX <: SHA2_CTX state::Array{UInt64,1} bytecount::UInt128 buffer::Array{UInt8,1} + used::Bool end function Base.getproperty(ctx::SHA2_CTX, fieldname::Symbol) @@ -48,6 +53,8 @@ function Base.getproperty(ctx::SHA2_CTX, fieldname::Symbol) return getfield(ctx, :buffer)::Vector{UInt8} elseif fieldname === :W return getfield(ctx, :W)::Vector{UInt32} + elseif fieldname === :used + return getfield(ctx, :used)::Bool else error("SHA2_CTX has no field ", fieldname) end @@ -67,24 +74,28 @@ mutable struct SHA3_224_CTX <: SHA3_CTX bytecount::UInt128 buffer::Array{UInt8,1} bc::Array{UInt64,1} + used::Bool end mutable struct SHA3_256_CTX <: SHA3_CTX state::Array{UInt64,1} bytecount::UInt128 buffer::Array{UInt8,1} bc::Array{UInt64,1} + used::Bool end mutable struct SHA3_384_CTX <: SHA3_CTX state::Array{UInt64,1} bytecount::UInt128 buffer::Array{UInt8,1} bc::Array{UInt64,1} + used::Bool end mutable struct SHA3_512_CTX <: SHA3_CTX state::Array{UInt64,1} bytecount::UInt128 buffer::Array{UInt8,1} bc::Array{UInt64,1} + used::Bool end function Base.getproperty(ctx::SHA3_CTX, fieldname::Symbol) @@ -96,6 +107,8 @@ function Base.getproperty(ctx::SHA3_CTX, fieldname::Symbol) return getfield(ctx, :buffer)::Vector{UInt8} elseif fieldname === :bc return getfield(ctx, :bc)::Vector{UInt64} + elseif fieldname === :used + return getfield(ctx, :used)::Bool else error("type ", typeof(ctx), " has no field ", fieldname) end @@ -145,50 +158,50 @@ short_blocklen(::Type{T}) where {T<:SHA_CTX} = blocklen(T) - 2*sizeof(state_type Construct an empty SHA2_224 context. """ -SHA2_224_CTX() = SHA2_224_CTX(copy(SHA2_224_initial_hash_value), 0, zeros(UInt8, blocklen(SHA2_224_CTX))) +SHA2_224_CTX() = SHA2_224_CTX(copy(SHA2_224_initial_hash_value), 0, zeros(UInt8, blocklen(SHA2_224_CTX)), false) """ SHA2_256_CTX() Construct an empty SHA2_256 context. """ -SHA2_256_CTX() = SHA2_256_CTX(copy(SHA2_256_initial_hash_value), 0, zeros(UInt8, blocklen(SHA2_256_CTX))) +SHA2_256_CTX() = SHA2_256_CTX(copy(SHA2_256_initial_hash_value), 0, zeros(UInt8, blocklen(SHA2_256_CTX)), false) """ SHA2_384() Construct an empty SHA2_384 context. """ -SHA2_384_CTX() = SHA2_384_CTX(copy(SHA2_384_initial_hash_value), 0, zeros(UInt8, blocklen(SHA2_384_CTX))) +SHA2_384_CTX() = SHA2_384_CTX(copy(SHA2_384_initial_hash_value), 0, zeros(UInt8, blocklen(SHA2_384_CTX)), false) """ SHA2_512_CTX() Construct an empty SHA2_512 context. """ -SHA2_512_CTX() = SHA2_512_CTX(copy(SHA2_512_initial_hash_value), 0, zeros(UInt8, blocklen(SHA2_512_CTX))) +SHA2_512_CTX() = SHA2_512_CTX(copy(SHA2_512_initial_hash_value), 0, zeros(UInt8, blocklen(SHA2_512_CTX)), false) """ SHA3_224_CTX() Construct an empty SHA3_224 context. """ -SHA3_224_CTX() = SHA3_224_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_224_CTX)), Vector{UInt64}(undef, 5)) +SHA3_224_CTX() = SHA3_224_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_224_CTX)), Vector{UInt64}(undef, 5), false) """ SHA3_256_CTX() Construct an empty SHA3_256 context. """ -SHA3_256_CTX() = SHA3_256_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_256_CTX)), Vector{UInt64}(undef, 5)) +SHA3_256_CTX() = SHA3_256_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_256_CTX)), Vector{UInt64}(undef, 5), false) """ SHA3_384_CTX() Construct an empty SHA3_384 context. """ -SHA3_384_CTX() = SHA3_384_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_384_CTX)), Vector{UInt64}(undef, 5)) +SHA3_384_CTX() = SHA3_384_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_384_CTX)), Vector{UInt64}(undef, 5), false) """ SHA3_512_CTX() Construct an empty SHA3_512 context. """ -SHA3_512_CTX() = SHA3_512_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_512_CTX)), Vector{UInt64}(undef, 5)) +SHA3_512_CTX() = SHA3_512_CTX(zeros(UInt64, 25), 0, zeros(UInt8, blocklen(SHA3_512_CTX)), Vector{UInt64}(undef, 5), false) # Nickname'd outer constructor methods for SHA2 const SHA224_CTX = SHA2_224_CTX @@ -202,13 +215,13 @@ const SHA512_CTX = SHA2_512_CTX Construct an empty SHA1 context. """ -SHA1_CTX() = SHA1_CTX(copy(SHA1_initial_hash_value), 0, zeros(UInt8, blocklen(SHA1_CTX)), Vector{UInt32}(undef, 80)) +SHA1_CTX() = SHA1_CTX(copy(SHA1_initial_hash_value), 0, zeros(UInt8, blocklen(SHA1_CTX)), Vector{UInt32}(undef, 80), false) # Copy functions -copy(ctx::T) where {T<:SHA1_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer), copy(ctx.W)) -copy(ctx::T) where {T<:SHA2_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer)) -copy(ctx::T) where {T<:SHA3_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer), Vector{UInt64}(undef, 5)) +copy(ctx::T) where {T<:SHA1_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer), copy(ctx.W), ctx.used) +copy(ctx::T) where {T<:SHA2_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer), ctx.used) +copy(ctx::T) where {T<:SHA3_CTX} = T(copy(ctx.state), ctx.bytecount, copy(ctx.buffer), Vector{UInt64}(undef, 5), ctx.used) # Make printing these types a little friendlier diff --git a/test/runtests.jl b/test/runtests.jl index 20e661f..796d985 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -63,6 +63,23 @@ end @test hash == answers[sha_funcs[sha_idx]][end] end end + + # Test that the hash states cannot be updated after having been finalized, + # but can still return the same digest + @testset "Reuse" begin + for sha_idx in 1:length(sha_funcs) + ctx = sha_types[sha_funcs[sha_idx]]() + update!(ctx, codeunits("abracadabra")) + hash1 = digest!(ctx) + + # Cannot update after having been digested + @test_throws Exception update!(ctx, codeunits("abc")) + + # But will still return the same digest twice + hash2 = digest!(ctx) + @test hash1 == hash2 + end + end end @testset "HMAC" begin