From 9d232bec7e1369d2cd4ab8a4c7221fd094982ba1 Mon Sep 17 00:00:00 2001
From: Dan Klishch
Date: Fri, 17 May 2024 14:20:27 -0400
Subject: [PATCH] AK+LibCompress: Implement streamable asynchronous deflate decompression

---
 AK/AsyncBitStream.h                           | 354 ++++++++++++++++++
 AK/AsyncStreamBuffer.h                        | 118 ++++++
 .../Libraries/LibCompress/AsyncDeflate.cpp    | 335 +++++++++++++++++
 Userland/Libraries/LibCompress/AsyncDeflate.h |  29 ++
 Userland/Libraries/LibCompress/CMakeLists.txt |   1 +
 Userland/Libraries/LibCompress/Deflate.cpp    |  26 ++
 Userland/Libraries/LibCompress/Deflate.h      |   3 +
 7 files changed, 866 insertions(+)
 create mode 100644 AK/AsyncBitStream.h
 create mode 100644 Userland/Libraries/LibCompress/AsyncDeflate.cpp
 create mode 100644 Userland/Libraries/LibCompress/AsyncDeflate.h

diff --git a/AK/AsyncBitStream.h b/AK/AsyncBitStream.h
new file mode 100644
index 00000000000000..59599aa48eec69
--- /dev/null
+++ b/AK/AsyncBitStream.h
@@ -0,0 +1,354 @@
+/*
+ * Copyright (c) 2024, Dan Klishch
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace AK {
+
+class AsyncInputLittleEndianBitStream;
+
+// A synchronous little-endian bit reader over a span of already buffered bytes. Reads that would
+// need more bits than are left fail with EAGAIN. Reading only advances the read head; bits are
+// reported as consumed through consume_bits(), which rollback_group() calls on success.
+class BufferBitView {
+    // These are defined just to replace some 4s and 8s with meaningful expressions.
+    using WordType = u32;
+    using DoubleWordType = u64;
+    static constexpr size_t bits_in_word = sizeof(WordType) * 8;
+
+public:
+    // `bit_position` is the number of low bits of the first byte of `bytes` to skip.
+    BufferBitView(ReadonlyBytes bytes, u8 bit_position)
+    {
+        auto ptr = reinterpret_cast(bytes.data());
+        auto buffer_offset_in_bytes = ptr % alignof(WordType);
+        auto bytes_in_current_word_to_fill = sizeof(WordType) - buffer_offset_in_bytes;
+
+        // Place the leading bytes at the same offset they have within their word in memory, so
+        // that everything after them can be picked up word by word from m_aligned_words.
+        m_bit_position = buffer_offset_in_bytes * 8 + bit_position;
+        m_bits_left = bytes.size() * 8 - bit_position;
+        memcpy(
+            reinterpret_cast(&m_current_and_next_word) + buffer_offset_in_bytes,
+            bytes.data(),
+            min(bytes_in_current_word_to_fill, bytes.size()));
+
+        if (bytes.size() > bytes_in_current_word_to_fill) {
+            m_aligned_words = ReadonlySpan {
+                reinterpret_cast(ptr + bytes_in_current_word_to_fill),
+                (bytes.size() - bytes_in_current_word_to_fill) / sizeof(WordType),
+            };
+            auto unaligned_end = bytes.slice(bytes_in_current_word_to_fill + m_aligned_words.size() * sizeof(WordType));
+            memcpy(&m_unaligned_end, unaligned_end.data(), unaligned_end.size());
+            refill_next_word();
+        }
+    }
+
+    size_t bits_left() const { return m_bits_left; }
+    size_t bits_consumed(Badge) const { return m_bits_consumed; }
+
+    // Returns the upcoming bits without consulting bits_left(); the caller is responsible for only
+    // trusting as many of them as are actually left.
+    WordType peek_bits_possibly_past_end() const
+    {
+        return m_current_and_next_word >> m_bit_position;
+    }
+
+    template
+    ErrorOr read_bits(u8 count)
+    {
+        static_assert(sizeof(T) <= sizeof(WordType));
+        VERIFY(count <= sizeof(T) * 8); // FIXME: Teach read_bits to read more than 32 bits.
+
+        if (bits_left() < count)
+            return Error::from_errno(EAGAIN);
+
+        T result = peek_bits_possibly_past_end() & ((1ULL << count) - 1);
+        advance_read_head(count);
+        return result;
+    }
+
+    ErrorOr read_bit()
+    {
+        if (!bits_left())
+            return Error::from_errno(EAGAIN);
+        bool result = m_current_and_next_word >> m_bit_position & 1;
+        advance_read_head(1);
+        return result;
+    }
+
+    void consume_bits(size_t count)
+    {
+        m_bits_consumed += count;
+    }
+
+    // Runs `func` and, if it succeeds, reports all bits it read as consumed. If it fails, the
+    // consumed count is left untouched (the read head itself is not moved back).
+    template
+    auto rollback_group(Func&& func)
+    {
+        auto bits_left_originally = m_bits_left;
+        auto result = func();
+        if (!result.is_error())
+            consume_bits(bits_left_originally - m_bits_left);
+        return result;
+    }
+
+private:
+    // ORs the next word into the upper half of m_current_and_next_word: an aligned word while any
+    // remain, then the trailing bytes (once), then zeroes.
+    void refill_next_word()
+    {
+        if (!m_aligned_words.is_empty()) {
+            m_current_and_next_word |= static_cast(m_aligned_words[0]) << bits_in_word;
+            m_aligned_words = m_aligned_words.slice(1);
+        } else {
+            m_current_and_next_word |= static_cast(m_unaligned_end) << bits_in_word;
+            m_unaligned_end = 0;
+        }
+    }
+
+    void advance_read_head(u8 bits)
+    {
+        m_bit_position += bits;
+        m_bits_left -= bits;
+        if (m_bit_position >= bits_in_word) {
+            m_bit_position -= bits_in_word;
+            m_current_and_next_word >>= bits_in_word;
+            refill_next_word();
+        }
+    }
+
+    u8 m_bit_position { 0 }; // bit offset inside current word
+    DoubleWordType m_current_and_next_word { 0 };
+    size_t m_bits_left { 0 };
+    size_t m_bits_consumed { 0 };
+
+    ReadonlySpan m_aligned_words;
+    WordType m_unaligned_end { 0 };
+};
+
+// Wraps an AsyncInputStream and keeps a bit offset into the first buffered byte of that stream.
+// buffered_data_unchecked() and dequeue() VERIFY that this offset is 0.
+class AsyncInputLittleEndianBitStream final : public AsyncInputStream {
+    AK_MAKE_NONCOPYABLE(AsyncInputLittleEndianBitStream);
+    AK_MAKE_NONMOVABLE(AsyncInputLittleEndianBitStream);
+
+public:
+    AsyncInputLittleEndianBitStream(MaybeOwned&& stream)
+        : m_stream(move(stream))
+    {
+    }
+
+    ~AsyncInputLittleEndianBitStream()
+    {
+        if (is_open())
+            reset();
+    }
+
+    void reset() override
+    {
+        VERIFY(is_open());
+        m_is_open = false;
+        m_stream->reset();
+    }
+
+    // Closing in the middle of a byte resets the stream and fails with EBUSY.
+    Coroutine> close() override
+    {
+        VERIFY(is_open());
+        if (m_bit_position != 0) {
+            reset();
+            co_return Error::from_errno(EBUSY);
+        }
+        m_is_open = false;
+        if (m_stream.is_owned())
+            co_return co_await m_stream->close();
+        co_return {};
+    }
+
+    bool is_open() const override { return m_is_open; }
+
+    Coroutine> enqueue_some(Badge) override
+    {
+        auto result = co_await m_stream->enqueue_some(badge());
+        if (result.is_error()) {
+            // Like read_bits() and peek_bits(): once the underlying stream reports an error, only
+            // mark this stream as closed and pass the error on, without touching m_stream again.
+            m_is_open = false;
+            co_return result;
+        }
+
+        if (buffered_data_unchecked(badge()).size() >= NumericLimits::max() / 8) [[unlikely]] {
+            // Can realistically only trigger on 32-bit.
+            // Go through reset() rather than m_stream->reset() so that m_is_open is cleared as
+            // well; otherwise the destructor would reset the underlying stream a second time.
+            reset();
+            co_return Error::from_string_literal("Too much data buffered");
+        }
+
+        co_return result;
+    }
+
+    ReadonlyBytes buffered_data_unchecked(Badge) const override
+    {
+        VERIFY(m_bit_position == 0);
+        return m_stream->buffered_data_unchecked(badge());
+    }
+
+    void dequeue(Badge, size_t bytes) override
+    {
+        VERIFY(m_bit_position == 0);
+        m_stream->dequeue(badge(), bytes);
+    }
+
+    size_t buffered_bits_count() const
+    {
+        return m_stream->buffered_data().size() * 8 - m_bit_position;
+    }
+
+    // Discards the remaining bits of a partially read byte, if there is one.
+    void align_to_byte_boundary()
+    {
+        if (m_bit_position != 0) {
+            m_bit_position = 0;
+            m_stream->dequeue(badge(), 1);
+        }
+    }
+
+    // Runs `func` on a BufferBitView of the buffered data, then dequeues what the view reports as
+    // consumed. EAGAIN from `func` is swallowed (it only sets m_is_reading_peek); any other error
+    // resets the stream and is returned.
+    template
+    ErrorOr with_bit_view_of_buffer(Func&& func)
+    {
+        BufferBitView bit_view { m_stream->buffered_data(), m_bit_position };
+        ErrorOr result = func(bit_view);
+
+        VERIFY(m_is_open);
+
+        if (result.is_error()) {
+            if (result.error().code() == EAGAIN) {
+                m_is_reading_peek = true;
+            } else {
+                reset();
+                return result.release_error();
+            }
+        } else {
+            m_is_reading_peek = false;
+        }
+
+        size_t offset = m_bit_position + bit_view.bits_consumed({});
+        m_bit_position = offset % 8;
+        if (offset >= 8)
+            m_stream->dequeue(badge(), offset / 8);
+
+        return {};
+    }
+
+    struct PeekBitsSyncResult {
+        u64 value;
+        size_t valid_bits;
+    };
+
+    // In AsyncInputStream terms, this always does a no-op peek of data. The precondition is that
+    // the current peek is non-reading, so this function can return 0 valid bits. For the sake of
+    // simplicity and performance, this function isn't guaranteed to return more than 57 bits (even
+    // if more data is available).
+    PeekBitsSyncResult peek_bits_sync()
+    {
+        VERIFY(!m_is_reading_peek); // Reading peek cannot ever be synchronous.
+        m_is_reading_peek = true;
+
+        auto data = m_stream->buffered_data();
+
+        u64 value = 0;
+        static_assert(HostIsLittleEndian);
+        if (data.size() > sizeof(value)) [[likely]] {
+            memcpy(&value, data.data(), sizeof(value));
+            value >>= m_bit_position;
+        } else {
+            memcpy(&value, data.data(), min(sizeof(value), data.size()));
+            value >>= m_bit_position;
+        }
+
+        return { .value = value, .valid_bits = min(64U, data.size() * 8) - m_bit_position };
+    }
+
+    // Awaits peek() with the bit position temporarily overridden to 0; the peeked data itself is
+    // not returned.
+    Coroutine> peek_bits()
+    {
+        TemporaryChange bit_position_change { m_bit_position, static_cast(0) };
+        auto data = co_await peek();
+        if (data.is_error()) {
+            m_is_open = false;
+            co_return data.release_error();
+        }
+        co_return {};
+    }
+
+    // Drops `count` bits, which must already be buffered.
+    void discard_bits(size_t count)
+    {
+        VERIFY(buffered_bits_count() >= count);
+
+        m_is_reading_peek = false;
+
+        size_t bytes_to_read = (m_bit_position + count) / 8;
+        if (bytes_to_read)
+            m_stream->dequeue(badge(), bytes_to_read);
+        m_bit_position = (m_bit_position + count) % 8;
+    }
+
+    // Reads `count` bits (at most 57), peeking the underlying stream until enough are buffered.
+    template
+    Coroutine> read_bits(size_t count)
+    {
+        VERIFY(!m_is_reading_peek);
+        VERIFY(count <= 57); // FIXME: Teach peek_bits_sync to peek more than 57 bits.
+
+        while (buffered_bits_count() < count) {
+            m_is_reading_peek = true;
+            auto result = co_await m_stream->peek();
+            if (result.is_error()) {
+                m_is_open = false;
+                co_return result.release_error();
+            }
+        }
+        m_is_reading_peek = false;
+
+        auto [value, valid_bits] = peek_bits_sync();
+        VERIFY(valid_bits >= count);
+        discard_bits(count);
+        co_return value & ((1ULL << count) - 1);
+    }
+
+    Coroutine> read_bit()
+    {
+        return read_bits(1);
+    }
+
+private:
+    MaybeOwned m_stream;
+    bool m_is_open { true };
+
+    u8 m_bit_position { 0 };
+};
+
+}
+
+#ifdef USING_AK_GLOBALLY
+using AK::AsyncInputLittleEndianBitStream;
+using AK::BufferBitView;
+#endif
diff --git a/AK/AsyncStreamBuffer.h b/AK/AsyncStreamBuffer.h
index 6355dc8f2f2363..799fbb03664d53 100644
--- a/AK/AsyncStreamBuffer.h
+++ b/AK/AsyncStreamBuffer.h
@@ -8,6 +8,7 @@
 
 #include
 #include
+#include
 
 namespace AK {
 
@@ -124,8 +125,125 @@ class AsyncStreamBuffer {
     u8* m_data { nullptr };
 };
 
+// An AsyncStreamBuffer that additionally mirrors everything written to it into a circular
+// seekback buffer, so that recently written data can be copied again by (distance, length).
+class AsyncStreamSeekbackBuffer {
+public:
+    // The absolute minimum for `seekback_buffer_size` is
+    // `max_seekback_distance + max_back_reference_length`, but providing a bigger value usually
+    // slightly improves performance.
+    AsyncStreamSeekbackBuffer(size_t max_seekback_distance, size_t seekback_buffer_size)
+        : m_seekback(MUST(FixedArray::create(seekback_buffer_size)))
+        , m_max_seekback_distance(max_seekback_distance)
+    {
+    }
+
+    ReadonlyBytes data() const
+    {
+        return m_buffer.data();
+    }
+
+    void dequeue(size_t bytes)
+    {
+        m_buffer.dequeue(bytes);
+    }
+
+    void write(ReadonlyBytes bytes)
+    {
+        m_buffer.append(bytes);
+        write_to_seekback(bytes);
+    }
+
+    void write(u8 byte)
+    {
+        m_buffer.append(byte);
+        write_to_seekback(byte);
+    }
+
+    // Appends `length` bytes that start `distance` bytes before the current end of the written
+    // data. `length` may exceed `distance`, in which case the copied bytes repeat.
+    void copy_from_seekback(size_t distance, size_t length)
+    {
+        VERIFY(distance <= max_seekback_distance());
+
+        auto buffer_bytes = m_buffer.get_bytes_for_writing(length);
+
+        while (length > 0) {
+            auto write = [&](ReadonlyBytes bytes) {
+                write_to_seekback(bytes, false);
+                buffer_bytes = buffer_bytes.slice(bytes.copy_to(buffer_bytes));
+            };
+
+            size_t to_copy = min(distance, length);
+            if (distance <= m_head) {
+                write(m_seekback.span().slice(m_head - distance, to_copy));
+            } else if (distance - to_copy > m_head) {
+                write(m_seekback.span().slice(m_seekback.size() - (distance - m_head), to_copy));
+            } else {
+                auto first_part = m_seekback.span().slice_from_end(distance - m_head);
+                auto second_part = m_seekback.span().slice(0, to_copy - distance + m_head);
+                write(first_part);
+                write(second_part);
+            }
+
+            // The chunk just written extended the data by `to_copy` bytes, so the original source
+            // is now that much further back.
+            distance += to_copy;
+            length -= to_copy;
+        }
+
+        VERIFY(buffer_bytes.is_empty());
+    }
+
+    // The configured maximum, capped by the number of bytes recorded so far.
+    size_t max_seekback_distance() const
+    {
+        return min(m_max_seekback_distance, m_seekback_length);
+    }
+
+private:
+    // Copies `bytes` into the circular buffer at m_head, wrapping around at most once. With
+    // `may_discard_prefix`, inputs longer than m_max_seekback_distance are cut down to their tail.
+    void write_to_seekback(ReadonlyBytes bytes, bool may_discard_prefix = true)
+    {
+        if (may_discard_prefix && bytes.size() > m_max_seekback_distance) {
+            bytes = bytes.slice_from_end(m_max_seekback_distance);
+            m_head = 0;
+        }
+
+        if (m_head + bytes.size() > m_seekback.size()) {
+            size_t first_part_size = m_seekback.size() - m_head;
+            size_t new_head = bytes.size() - first_part_size;
+
+            memcpy(m_seekback.data() + m_head, bytes.data(), first_part_size);
+            memcpy(m_seekback.data(), bytes.slice(first_part_size).data(), new_head);
+            m_head = new_head;
+        } else {
+            memcpy(m_seekback.data() + m_head, bytes.data(), bytes.size());
+            m_head += bytes.size();
+            if (m_head == m_seekback.size())
+                m_head = 0;
+        }
+        m_seekback_length += bytes.size();
+    }
+
+    void write_to_seekback(u8 byte)
+    {
+        m_seekback[m_head] = byte;
+        m_head = (m_head + 1 == m_seekback.size() ? 0 : m_head + 1);
+        ++m_seekback_length;
+    }
+
+    AsyncStreamBuffer m_buffer;
+    FixedArray m_seekback;
+    size_t m_head { 0 };
+    u64 m_seekback_length { 0 };
+    size_t m_max_seekback_distance { 0 };
+};
+
 }
 
 #ifdef USING_AK_GLOBALLY
 using AK::AsyncStreamBuffer;
+using AK::AsyncStreamSeekbackBuffer;
 #endif
diff --git a/Userland/Libraries/LibCompress/AsyncDeflate.cpp b/Userland/Libraries/LibCompress/AsyncDeflate.cpp
new file mode 100644
index 00000000000000..dec2af0586ff36
--- /dev/null
+++ b/Userland/Libraries/LibCompress/AsyncDeflate.cpp
@@ -0,0 +1,335 @@
+/*
+ * Copyright (c) 2024, Dan Klishch
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+#include
+#include
+
+namespace Compress::Async {
+
+namespace {
+constexpr size_t max_seekback_distance = 32 * KiB;
+constexpr size_t max_back_reference_length = 258;
+
+// Decodes the symbols of one Huffman-compressed block into `write_buffer`.
+struct CompressedBlock {
+public:
+    CompressedBlock(AsyncStreamSeekbackBuffer& write_buffer, CanonicalCode const& literal_codes, Optional distance_codes)
+        : m_write_buffer(write_buffer)
+        , m_literal_codes(literal_codes)
+        , m_distance_codes(distance_codes)
+    {
+    }
+
+    struct ReadResult {
+        bool read_something { false };
+        bool is_eof { false };
+    };
+
+    // Decodes as many symbols as the currently buffered input allows. `is_eof` is set once the
+    // end-of-block symbol (256) has been read.
+    ErrorOr read_current_chunk(AsyncInputLittleEndianBitStream* stream)
+    {
+        bool read_at_least_one_symbol = false;
+        bool is_eof = false;
+        TRY(stream->with_bit_view_of_buffer([&](BufferBitView& bit_view) -> ErrorOr {
+            while (true) {
+                auto result = TRY(read_symbol(bit_view));
+                if (result) {
+                    read_at_least_one_symbol = true;
+                } else {
+                    is_eof = true;
+                    return {};
+                }
+            }
+        }));
+        return ReadResult { read_at_least_one_symbol, is_eof };
+    }
+
+private:
+    // Maps a length symbol (257..285) to a back-reference length, reading extra bits if needed.
+    ErrorOr decode_length(BufferBitView& bit_view, u32 symbol)
+    {
+        if (symbol <= 264)
+            return symbol - 254;
+
+        if (symbol <= 284) {
+            auto extra_bits_count = (symbol - 261) / 4;
+            return (((symbol - 265) % 4 + 4) << extra_bits_count) + 3 + TRY(bit_view.read_bits(extra_bits_count));
+        }
+
+        if (symbol == 285)
+            return max_back_reference_length;
+
+        VERIFY_NOT_REACHED();
+    }
+
+    // Maps a distance symbol (0..29) to a back-reference distance, reading extra bits if needed.
+    ErrorOr decode_distance(BufferBitView& bit_view, u32 symbol)
+    {
+        if (symbol <= 3)
+            return symbol + 1;
+
+        if (symbol <= 29) {
+            auto extra_bits_count = (symbol / 2) - 1;
+            return ((symbol % 2 + 2) << extra_bits_count) + 1 + TRY(bit_view.read_bits(extra_bits_count));
+        }
+
+        VERIFY_NOT_REACHED();
+    }
+
+    ErrorOr read_symbol(BufferBitView& bit_view) // True if read bytes, false if read EOF
+    {
+        return bit_view.rollback_group([&] -> ErrorOr {
+            u32 symbol = TRY(m_literal_codes.read_symbol(bit_view));
+
+            if (symbol >= 286)
+                return Error::from_string_literal("Invalid deflate literal/length symbol");
+
+            if (symbol < 256) {
+                m_write_buffer.write(static_cast(symbol));
+                return true;
+            }
+
+            if (symbol == 256)
+                return false;
+
+            if (!m_distance_codes.has_value())
+                return Error::from_string_literal("Distance codes have not been initialized in this block");
+
+            auto length = TRY(decode_length(bit_view, symbol));
+
+            u32 distance_symbol = TRY(m_distance_codes->read_symbol(bit_view));
+            if (distance_symbol >= 30)
+                return Error::from_string_literal("Invalid deflate distance symbol");
+
+            auto distance = TRY(decode_distance(bit_view, distance_symbol));
+
+            if (distance > m_write_buffer.max_seekback_distance())
+                return Error::from_string_literal("Provided seekback distance is larger than the amount of data available in seekback buffer");
+
+            m_write_buffer.copy_from_seekback(distance, length);
+            return true;
+        });
+    }
+
+    AsyncStreamSeekbackBuffer& m_write_buffer;
+    CanonicalCode const& m_literal_codes;
+    Optional m_distance_codes;
+};
+
+// Accumulates code lengths (including the repeat symbols 16, 17 and 18) until at least `length`
+// of them have been decoded.
+struct CodeLengthsDecompressor {
+public:
+    CodeLengthsDecompressor(size_t length, CanonicalCode const& code_length_code)
+        : m_required_length(length)
+        , m_code_length_code(code_length_code)
+    {
+    }
+
+    bool is_done() const
+    {
+        return m_code_lengths.size() >= m_required_length;
+    }
+
+    ErrorOr read_current_chunk(AsyncInputLittleEndianBitStream* stream)
+    {
+        return stream->with_bit_view_of_buffer([&](BufferBitView& bit_view) -> ErrorOr {
+            while (!is_done())
+                TRY(read_symbol(bit_view));
+            return {};
+        });
+    }
+
+    Vector&& take_code_lengths() { return move(m_code_lengths); }
+
+private:
+    static constexpr u8 deflate_special_code_length_copy = 16;
+    static constexpr u8 deflate_special_code_length_zeros = 17;
+    static constexpr u8 deflate_special_code_length_long_zeros = 18;
+
+    ErrorOr read_symbol(BufferBitView& bit_view)
+    {
+        return bit_view.rollback_group([&] -> ErrorOr {
+            auto symbol = TRY(m_code_length_code.read_symbol(bit_view));
+
+            if (symbol < deflate_special_code_length_copy) {
+                m_code_lengths.append(static_cast(symbol));
+            } else if (symbol == deflate_special_code_length_copy) {
+                if (m_code_lengths.is_empty())
+                    return Error::from_string_literal("Found no codes to copy before a copy block");
+                auto nrepeat = 3 + TRY(bit_view.read_bits(2));
+                for (size_t j = 0; j < nrepeat; ++j)
+                    m_code_lengths.append(m_code_lengths.last());
+            } else if (symbol == deflate_special_code_length_zeros) {
+                auto nrepeat = 3 + TRY(bit_view.read_bits(3));
+                for (size_t j = 0; j < nrepeat; ++j)
+                    m_code_lengths.append(0);
+            } else {
+                VERIFY(symbol == deflate_special_code_length_long_zeros);
+                auto nrepeat = 11 + TRY(bit_view.read_bits(7));
+                for (size_t j = 0; j < nrepeat; ++j)
+                    m_code_lengths.append(0);
+            }
+            return {};
+        });
+    }
+
+    size_t m_required_length { 0 };
+    Vector m_code_lengths;
+    CanonicalCode const& m_code_length_code;
+};
+
+// Reads the code descriptions at the start of a block of type 0b10. `distance_code` is left
+// untouched when the block declares a single distance code of length 0.
+Coroutine> decode_codes(AsyncInputLittleEndianBitStream* stream, CanonicalCode& literal_code, Optional& distance_code)
+{
+    auto literal_code_count = CO_TRY(co_await stream->read_bits(5)) + 257;
+    auto distance_code_count = CO_TRY(co_await stream->read_bits(5)) + 1;
+    auto code_length_count = CO_TRY(co_await stream->read_bits(4)) + 4;
+
+    // First we have to extract the code lengths of the code that was used to encode the code lengths of
+    // the code that was used to encode the block.
+
+    auto packed_code_lengths_code_lengths = CO_TRY(co_await stream->read_bits(code_length_count * 3));
+    u8 code_lengths_code_lengths[19] = { 0 };
+
+    for (size_t i = 0; i < code_length_count; ++i) {
+        code_lengths_code_lengths[code_lengths_code_lengths_order[i]] = packed_code_lengths_code_lengths & 7;
+        packed_code_lengths_code_lengths >>= 3;
+    }
+
+    // Now we can extract the code that was used to encode the code lengths of the code that was used to
+    // encode the block.
+    auto code_length_code_or_error = CanonicalCode::from_bytes({ code_lengths_code_lengths, sizeof(code_lengths_code_lengths) });
+    if (code_length_code_or_error.is_error()) {
+        stream->reset();
+        co_return code_length_code_or_error.release_error();
+    }
+    auto const& code_length_code = code_length_code_or_error.value();
+
+    // Next we extract the code lengths of the code that was used to encode the block.
+    CodeLengthsDecompressor code_lengths_decompressor { literal_code_count + distance_code_count, code_length_code };
+    while (!code_lengths_decompressor.is_done()) {
+        CO_TRY(co_await stream->peek_bits());
+        CO_TRY(code_lengths_decompressor.read_current_chunk(stream));
+    }
+    Vector code_lengths = code_lengths_decompressor.take_code_lengths();
+
+    if (code_lengths.size() != literal_code_count + distance_code_count) {
+        stream->reset();
+        co_return Error::from_string_literal("Number of code lengths does not match the sum of codes");
+    }
+
+    // Now we extract the code that was used to encode literals and lengths in the block.
+    auto literal_code_or_error = CanonicalCode::from_bytes(code_lengths.span().trim(literal_code_count));
+    if (literal_code_or_error.is_error()) {
+        stream->reset();
+        co_return literal_code_or_error.release_error();
+    }
+    literal_code = literal_code_or_error.release_value();
+
+    // Now we extract the code that was used to encode distances in the block.
+    if (distance_code_count == 1) {
+        auto length = code_lengths[literal_code_count];
+
+        if (length == 0) {
+            co_return {};
+        } else if (length != 1) {
+            stream->reset();
+            co_return Error::from_string_literal("Length for a single distance code is longer than 1");
+        }
+    }
+
+    auto distance_code_or_error = CanonicalCode::from_bytes(code_lengths.span().slice(literal_code_count));
+    if (distance_code_or_error.is_error()) {
+        stream->reset();
+        co_return distance_code_or_error.release_error();
+    }
+    distance_code = distance_code_or_error.release_value();
+
+    co_return {};
+}
+}
+
+DeflateDecompressor::DeflateDecompressor(NonnullOwnPtr&& input)
+    : AsyncStreamTransform(make(move(input)), decompress())
+    , m_buffer(max_seekback_distance, 2 * (max_seekback_distance + max_back_reference_length))
+{
+}
+
+ReadonlyBytes DeflateDecompressor::buffered_data_unchecked(Badge) const
+{
+    return m_buffer.data();
+}
+
+void DeflateDecompressor::dequeue(Badge, size_t bytes)
+{
+    return m_buffer.dequeue(bytes);
+}
+
+// Decodes blocks until one flagged as final has been processed, yielding after output has been
+// written to m_buffer.
+auto DeflateDecompressor::decompress() -> Generator
+{
+    while (true) {
+        bool is_final_block = CO_TRY(co_await m_stream->read_bit());
+        auto block_type = CO_TRY(co_await m_stream->read_bits(2));
+
+        if (block_type == 0b00) {
+            m_stream->align_to_byte_boundary();
+
+            size_t length = CO_TRY(co_await m_stream->read_object>());
+            size_t negated_length = CO_TRY(co_await m_stream->read_object>());
+
+            if ((length ^ 0xffff) != negated_length) {
+                m_stream->reset();
+                co_return Error::from_string_literal("Calculated negated length does not equal stored negated length");
+            }
+
+            while (length > 0) {
+                auto data = CO_TRY(co_await m_stream->peek());
+                size_t to_copy = min(data.size(), length);
+                m_buffer.write(must_sync(m_stream->read(to_copy)));
+                length -= to_copy;
+                co_yield {};
+            }
+        } else if (block_type == 0b01 || block_type == 0b10) {
+            CanonicalCode literal_codes;
+            Optional distance_codes;
+            Optional block;
+
+            if (block_type == 0b01) {
+                block = CompressedBlock { m_buffer, CanonicalCode::fixed_literal_codes(), CanonicalCode::fixed_distance_codes() };
+            } else {
+                CO_TRY(co_await decode_codes(m_stream.ptr(), literal_codes, distance_codes));
+                if (distance_codes.has_value())
+                    block = CompressedBlock { m_buffer, literal_codes, distance_codes.value() };
+                else
+                    block = CompressedBlock { m_buffer, literal_codes, {} };
+            }
+
+            while (true) {
+                CO_TRY(co_await m_stream->peek_bits());
+                auto [read_something, is_eof] = CO_TRY(block->read_current_chunk(m_stream.ptr()));
+                if (read_something)
+                    co_yield {};
+                if (is_eof)
+                    break;
+            }
+        } else {
+            m_stream->reset();
+            co_return Error::from_string_literal("Invalid block type");
+        }
+
+        if (is_final_block) {
+            m_stream->align_to_byte_boundary();
+            co_return {};
+        }
+    }
+}
+
+}
diff --git a/Userland/Libraries/LibCompress/AsyncDeflate.h b/Userland/Libraries/LibCompress/AsyncDeflate.h
new file mode 100644
index 00000000000000..6401d406ed229f
--- /dev/null
+++ b/Userland/Libraries/LibCompress/AsyncDeflate.h
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2024, Dan Klishch
+ *
+ * SPDX-License-Identifier: BSD-2-Clause
+ */
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace Compress::Async {
+
+// Asynchronous deflate decompressor built on AsyncStreamTransform, with m_buffer as its output buffer.
+class DeflateDecompressor final : public AsyncStreamTransform {
+public:
+    DeflateDecompressor(NonnullOwnPtr&& input);
+
+    ReadonlyBytes buffered_data_unchecked(Badge) const override;
+    void dequeue(Badge, size_t bytes) override;
+
+private:
+    Generator decompress();
+
+    AsyncStreamSeekbackBuffer m_buffer;
+};
+
+}
diff --git a/Userland/Libraries/LibCompress/CMakeLists.txt b/Userland/Libraries/LibCompress/CMakeLists.txt
index e903a1ba5d8e61..1f281a2c20d447 100644
--- a/Userland/Libraries/LibCompress/CMakeLists.txt
+++ b/Userland/Libraries/LibCompress/CMakeLists.txt
@@ -1,4 +1,5 @@
 set(SOURCES
+    AsyncDeflate.cpp
     Brotli.cpp
     BrotliDictionary.cpp
     Deflate.cpp
diff --git a/Userland/Libraries/LibCompress/Deflate.cpp b/Userland/Libraries/LibCompress/Deflate.cpp
index d35201af431607..40fc6fcbacbaef 100644
--- a/Userland/Libraries/LibCompress/Deflate.cpp
+++ b/Userland/Libraries/LibCompress/Deflate.cpp
@@ -166,6 +166,32 @@ ErrorOr CanonicalCode::read_symbol(LittleEndianInputBitStream& stream) cons
     return Error::from_string_literal("Symbol exceeds maximum symbol number");
 }
 
+// Decodes one symbol from `bit_view`: first through the prefix table, then, for longer codes, by
+// extending the code one bit at a time. Fails with EAGAIN if `bit_view` runs out of bits.
+ErrorOr CanonicalCode::read_symbol(BufferBitView& bit_view) const
+{
+    auto prefix = bit_view.peek_bits_possibly_past_end() & ((1 << m_max_prefixed_code_length) - 1);
+
+    if (auto [symbol_value, code_length] = m_prefix_table[prefix]; code_length != 0) {
+        TRY(bit_view.read_bits(code_length));
+        return symbol_value;
+    }
+
+    auto code_bits = TRY(bit_view.read_bits(m_max_prefixed_code_length));
+    code_bits = fast_reverse16(code_bits, m_max_prefixed_code_length);
+    code_bits |= 1 << m_max_prefixed_code_length;
+
+    for (size_t i = m_max_prefixed_code_length; i < 16; ++i) {
+        size_t index;
+        if (binary_search(m_symbol_codes.span(), code_bits, &index))
+            return m_symbol_values[index];
+
+        code_bits = code_bits << 1 | TRY(bit_view.read_bit());
+    }
+
+    return Error::from_string_literal("Symbol exceeds maximum symbol number");
+}
+
 ErrorOr CanonicalCode::write_symbol(LittleEndianOutputBitStream& stream, u32 symbol) const
 {
     auto code = symbol < m_bit_codes.size() ? m_bit_codes[symbol] : 0u;
diff --git a/Userland/Libraries/LibCompress/Deflate.h b/Userland/Libraries/LibCompress/Deflate.h
index 00f237a0b98bce..acc87af2799c93 100644
--- a/Userland/Libraries/LibCompress/Deflate.h
+++ b/Userland/Libraries/LibCompress/Deflate.h
@@ -7,6 +7,7 @@
 
 #pragma once
 
+#include
 #include
 #include
 #include
@@ -21,7 +22,9 @@ namespace Compress {
 class CanonicalCode {
 public:
     CanonicalCode() = default;
+
     ErrorOr read_symbol(LittleEndianInputBitStream&) const;
+    ErrorOr read_symbol(BufferBitView& bit_view) const;
     ErrorOr write_symbol(LittleEndianOutputBitStream&, u32) const;
 
     static CanonicalCode const& fixed_literal_codes();