diff --git a/clients/src/main/java/org/apache/kafka/common/compress/Lz4BlockInputStream.java b/clients/src/main/java/org/apache/kafka/common/compress/Lz4BlockInputStream.java index 27ca3c589b1e5..f8cdbccf72287 100644 --- a/clients/src/main/java/org/apache/kafka/common/compress/Lz4BlockInputStream.java +++ b/clients/src/main/java/org/apache/kafka/common/compress/Lz4BlockInputStream.java @@ -19,11 +19,13 @@ import org.apache.kafka.common.compress.Lz4BlockOutputStream.BD; import org.apache.kafka.common.compress.Lz4BlockOutputStream.FLG; import org.apache.kafka.common.utils.internals.BufferSupplier; +import org.apache.kafka.common.utils.internals.Checksums; import net.jpountz.lz4.LZ4Compressor; import net.jpountz.lz4.LZ4Exception; import net.jpountz.lz4.LZ4Factory; import net.jpountz.lz4.LZ4SafeDecompressor; +import net.jpountz.xxhash.StreamingXXHash32; import net.jpountz.xxhash.XXHash32; import net.jpountz.xxhash.XXHashFactory; @@ -31,6 +33,7 @@ import java.io.InputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.zip.Checksum; import static org.apache.kafka.common.compress.Lz4BlockOutputStream.LZ4_FRAME_INCOMPRESSIBLE_MASK; import static org.apache.kafka.common.compress.Lz4BlockOutputStream.MAGIC; @@ -48,9 +51,11 @@ public final class Lz4BlockInputStream extends InputStream { public static final String NOT_SUPPORTED = "Stream unsupported (invalid magic bytes)"; public static final String BLOCK_HASH_MISMATCH = "Block checksum mismatch"; public static final String DESCRIPTOR_HASH_MISMATCH = "Stream frame descriptor corrupted"; + public static final String CONTENT_HASH_MISMATCH = "Content checksum mismatch"; private static final LZ4SafeDecompressor DECOMPRESSOR = LZ4Factory.fastestInstance().safeDecompressor(); - private static final XXHash32 CHECKSUM = XXHashFactory.fastestInstance().hash32(); + private static final XXHashFactory HASH_FACTORY = XXHashFactory.fastestInstance(); + private static final XXHash32 CHECKSUM = 
HASH_FACTORY.hash32(); private static final RuntimeException BROKEN_LZ4_EXCEPTION; // https://issues.apache.org/jira/browse/KAFKA-9203 @@ -77,6 +82,10 @@ public final class Lz4BlockInputStream extends InputStream { // If a block is compressed, this is the same as `decompressionBuffer`. If a block is not compressed, this is // a slice of `in` to avoid unnecessary copies. private ByteBuffer decompressedBuffer; + // Running XXHash32 over the decompressed content; both fields are non-null iff the frame's FLG.contentChecksum + // bit is set. `contentHashAsChecksum` is the j.u.z.Checksum view of `contentHash` used by `Checksums.update`. + private StreamingXXHash32 contentHash; + private Checksum contentHashAsChecksum; private boolean finished; /** @@ -95,6 +104,10 @@ public Lz4BlockInputStream(ByteBuffer in, BufferSupplier bufferSupplier, boolean this.bufferSupplier = bufferSupplier; readHeader(); decompressionBuffer = bufferSupplier.get(maxBlockSize); + if (flg.isContentChecksumSet()) { + contentHash = HASH_FACTORY.newStreamingHash32(0); + contentHashAsChecksum = contentHash.asChecksum(); + } finished = false; } @@ -169,8 +182,14 @@ private void readBlock() throws IOException { // Check for EndMark if (blockSize == 0) { finished = true; - if (flg.isContentChecksumSet()) - in.getInt(); // TODO: verify this content checksum + if (flg.isContentChecksumSet()) { + int expected = in.getInt(); + // Read directly from StreamingXXHash32: the lz4-java 1.10.2 asChecksum() adapter masks the + // returned value with 0xFFFFFFFL (28 bits) instead of 0xFFFFFFFFL, dropping the top 4 bits. 
+ if (contentHash.getValue() != expected) { + throw new IOException(CONTENT_HASH_MISMATCH); + } + } return; } else if (blockSize > maxBlockSize) { throw new IOException(String.format("Block size %d exceeded max: %d", blockSize, maxBlockSize)); @@ -195,6 +214,11 @@ private void readBlock() throws IOException { decompressedBuffer.limit(blockSize); } + // Update running content hash before the consumer can advance decompressedBuffer's position. + if (contentHashAsChecksum != null) { + Checksums.update(contentHashAsChecksum, decompressedBuffer, 0, decompressedBuffer.remaining()); + } + // verify checksum if (flg.isBlockChecksumSet()) { int hash = CHECKSUM.hash(in, in.position(), blockSize, 0); @@ -264,6 +288,11 @@ public int available() { @Override public void close() { bufferSupplier.release(decompressionBuffer); + if (contentHash != null) { + contentHash.close(); + contentHash = null; + contentHashAsChecksum = null; + } } @Override diff --git a/clients/src/test/java/org/apache/kafka/common/compress/Lz4CompressionTest.java b/clients/src/test/java/org/apache/kafka/common/compress/Lz4CompressionTest.java index b621e23b7ee30..5ee7ded20b745 100644 --- a/clients/src/test/java/org/apache/kafka/common/compress/Lz4CompressionTest.java +++ b/clients/src/test/java/org/apache/kafka/common/compress/Lz4CompressionTest.java @@ -21,6 +21,7 @@ import org.apache.kafka.common.utils.internals.BufferSupplier; import org.apache.kafka.common.utils.internals.ChunkedBytesStream; +import net.jpountz.lz4.LZ4FrameOutputStream; import net.jpountz.xxhash.XXHashFactory; import org.junit.jupiter.api.Test; @@ -433,6 +434,65 @@ private void testDecompression(ByteBuffer buffer, Args args) throws IOException if (!args.close) assertNotNull(error); } + @Test + public void testContentChecksumVerificationSuccess() throws IOException { + byte[] payload = String.join("", Collections.nCopies(64, "content-checksum-verify")) + .getBytes(StandardCharsets.UTF_8); + byte[] framed = withContentChecksum(payload, 
false); + + try (Lz4BlockInputStream in = new Lz4BlockInputStream( + ByteBuffer.wrap(framed), BufferSupplier.create(), false)) { + assertArrayEquals(payload, in.readAllBytes()); + } + } + + @Test + public void testContentChecksumVerificationFailure() throws IOException { + byte[] payload = String.join("", Collections.nCopies(64, "content-checksum-verify")) + .getBytes(StandardCharsets.UTF_8); + byte[] framed = withContentChecksum(payload, true); + + try (Lz4BlockInputStream in = new Lz4BlockInputStream( + ByteBuffer.wrap(framed), BufferSupplier.create(), false)) { + IOException e = assertThrows(IOException.class, in::readAllBytes); + assertEquals(Lz4BlockInputStream.CONTENT_HASH_MISMATCH, e.getMessage()); + } + } + + @Test + public void testContentChecksumVerificationSuccessDirectBuffer() throws IOException { + byte[] payload = new byte[8 * 1024]; + RANDOM.nextBytes(payload); + byte[] framed = withContentChecksum(payload, false); + + ByteBuffer direct = ByteBuffer.allocateDirect(framed.length); + direct.put(framed).flip(); + + try (Lz4BlockInputStream in = new Lz4BlockInputStream(direct, BufferSupplier.create(), false)) { + assertArrayEquals(payload, in.readAllBytes()); + } + } + + /** + * Build a spec-compliant LZ4 frame with the contentChecksum FLG bit set, using lz4-java's + * own framed writer. BLOCK_INDEPENDENCE is required by Kafka's reader and must be passed explicitly here: supplying any FLG bits to this constructor replaces its defaults. 
+ */ + private byte[] withContentChecksum(byte[] payload, boolean corruptTrailer) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (LZ4FrameOutputStream lz4 = new LZ4FrameOutputStream( + out, + LZ4FrameOutputStream.BLOCKSIZE.SIZE_64KB, + LZ4FrameOutputStream.FLG.Bits.BLOCK_INDEPENDENCE, + LZ4FrameOutputStream.FLG.Bits.CONTENT_CHECKSUM)) { + lz4.write(payload); + } + byte[] framed = out.toByteArray(); + if (corruptTrailer) { + framed[framed.length - 1] ^= 0x01; + } + return framed; + } + private byte[] compressedBytes(Args args) throws IOException { ByteArrayOutputStream output = new ByteArrayOutputStream(); Lz4BlockOutputStream lz4 = new Lz4BlockOutputStream(