From 699575c688591011bdc61c9598f45734bcf8b6e6 Mon Sep 17 00:00:00 2001 From: Akira Ajisaka Date: Fri, 27 Mar 2026 18:12:56 +0900 Subject: [PATCH 1/3] [SPARK-56227][CORE] Fix GcmTransportCipher to correctly handle multiple messages per channel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Four bugs in `GcmTransportCipher` cause failures in production YARN clusters when AES-GCM RPC encryption is enabled (`spark.network.crypto.cipher=AES/GCM/NoPadding`). **Bug 1 — DecryptionHandler is single-use per channel (YARN container launch failure)** After decoding the first post-auth message, `completed = true` was never reset. `AesGcmHkdfStreaming` is a one-shot streaming primitive: each GCM message carries its own random IV and requires a fresh `StreamSegmentDecrypter`. With `decrypter` declared `final` and all guard flags stuck at their terminal values, every subsequent message on the channel was silently discarded. Fix: make `decrypter` non-final, add `resetForNextMessage()` that reinstates all per-message state (including a fresh `StreamSegmentDecrypter`), and call it after each successfully decoded message. **Bug 2 — TCP-coalesced messages lost (SparkSQL IllegalStateException)** When TCP delivers multiple back-to-back GCM messages in a single `channelRead()` call (common under shuffle load), the old code released the `ByteBuf` after decoding the first message, discarding any remaining bytes. The next `channelRead()` then read bytes from the middle of the second message as its length header, producing a negative `long` and throwing `IllegalStateException("Invalid expected ciphertext length")`. Fix: wrap the decode logic in an outer `while(true)` loop that exhausts all complete messages from the buffer before releasing it; call `resetForNextMessage()` inside the loop between messages. 
**Bug 3 — TCP-fragmented frame header causes IndexOutOfBoundsException (benchmark)** `ByteBuf.readBytes(ByteBuffer dst)` requires exactly `dst.remaining()` bytes to be present and throws `IndexOutOfBoundsException` if the source is shorter. Under high load, TCP can fragment a GCM message's 24-byte internal header (or 8-byte length prefix) across multiple `channelRead()` calls. The code incorrectly assumed `readBytes` would stop early and leave `hasRemaining() == true`. Fix: compute `toRead = Math.min(readable, dst.remaining())`, temporarily narrow `dst.limit` to `position + toRead`, call `readBytes(dst)`, then restore `limit`. **Bug 4 — EncryptionHandler shares mutable buffers across GcmEncryptedMessage instances** `plaintextBuffer` and `ciphertextBuffer` were `EncryptionHandler` fields reused across all `GcmEncryptedMessage` instances. Under Netty's write pipeline a new message can be constructed (via `write()`) before a prior one's `transferTo()` completes; the new constructor's `ciphertextBuffer.limit(0)` would corrupt the in-flight message's buffer. Fix: allocate `plaintextBuffer` and `ciphertextBuffer` inside the `GcmEncryptedMessage` constructor so each message owns its own buffers. 
- Cache `headerLength` in `DecryptionHandler` to avoid repeated `getHeaderLength()` calls - Replace `Integer.min()` with `Math.min()` for style consistency - `testMultipleMessages`: regression for Bug 1 — same `DecryptionHandler` decodes two independent messages delivered via separate `channelRead()` calls - `testBatchedMessages`: regression for Bug 2 — two ciphertexts concatenated into one `ByteBuf` and delivered in a single `channelRead()` call - `testSplitHeader`: regression for Bug 3 — ciphertext split at byte 12 (8-byte length field + 4 bytes into the 24-byte GCM header) across two `channelRead()` calls Co-Authored-By: Claude Opus 4.6 --- .../network/crypto/GcmTransportCipher.java | 187 +++++++++------ .../network/crypto/GcmAuthEngineSuite.java | 216 ++++++++++++++++++ 2 files changed, 334 insertions(+), 69 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java index d3f1bf490d3a3..4ecbd647151a6 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java @@ -82,25 +82,16 @@ public void addToChannel(Channel ch) throws GeneralSecurityException { @VisibleForTesting class EncryptionHandler extends ChannelOutboundHandlerAdapter { - private final ByteBuffer plaintextBuffer; - private final ByteBuffer ciphertextBuffer; private final AesGcmHkdfStreaming aesGcmHkdfStreaming; EncryptionHandler() throws InvalidAlgorithmParameterException { aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); - plaintextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize()); - ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); } @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - 
GcmEncryptedMessage encryptedMessage = new GcmEncryptedMessage( - aesGcmHkdfStreaming, - msg, - plaintextBuffer, - ciphertextBuffer); - ctx.write(encryptedMessage, promise); + ctx.write(new GcmEncryptedMessage(aesGcmHkdfStreaming, msg), promise); } } @@ -116,15 +107,15 @@ static class GcmEncryptedMessage extends AbstractFileRegion { private final long encryptedCount; GcmEncryptedMessage(AesGcmHkdfStreaming aesGcmHkdfStreaming, - Object plaintextMessage, - ByteBuffer plaintextBuffer, - ByteBuffer ciphertextBuffer) throws GeneralSecurityException { + Object plaintextMessage) throws GeneralSecurityException { Preconditions.checkArgument( plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion, "Unrecognized message type: %s", plaintextMessage.getClass().getName()); this.plaintextMessage = plaintextMessage; - this.plaintextBuffer = plaintextBuffer; - this.ciphertextBuffer = ciphertextBuffer; + this.plaintextBuffer = + ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize()); + this.ciphertextBuffer = + ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); // If the ciphertext buffer cannot be fully written the target, transferTo may // return with it containing some unwritten data. The initial call we'll explicitly // set its limit to 0 to indicate the first call to transferTo. 
@@ -297,7 +288,8 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter { private final ByteBuffer headerBuffer; private final ByteBuffer ciphertextBuffer; private final AesGcmHkdfStreaming aesGcmHkdfStreaming; - private final StreamSegmentDecrypter decrypter; + private StreamSegmentDecrypter decrypter; + private final int headerLength; private final int plaintextSegmentSize; private boolean decrypterInit = false; private boolean completed = false; @@ -307,17 +299,46 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter { DecryptionHandler() throws GeneralSecurityException { aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); + headerLength = aesGcmHkdfStreaming.getHeaderLength(); expectedLengthBuffer = ByteBuffer.allocate(LENGTH_HEADER_BYTES); - headerBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getHeaderLength()); + headerBuffer = ByteBuffer.allocate(headerLength); ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter(); plaintextSegmentSize = aesGcmHkdfStreaming.getPlaintextSegmentSize(); } - private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) { + /** + * Resets all per-message state so that the next incoming GCM message can be + * decoded through the same channel handler instance. This must be called after + * every successfully completed message because AesGcmHkdfStreaming is a one-shot + * streaming primitive: each encrypted message carries its own random IV and must + * be decrypted with a fresh StreamSegmentDecrypter. 
+ */ + private void resetForNextMessage() throws GeneralSecurityException { + expectedLength = -1; + expectedLengthBuffer.clear(); + headerBuffer.clear(); + ciphertextBuffer.clear(); + decrypterInit = false; + completed = false; + segmentNumber = 0; + ciphertextRead = 0; + decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter(); + } + + private boolean initializeExpectedLength(ByteBuf ciphertextNettyBuf) { if (expectedLength < 0) { - ciphertextNettyBuf.readBytes(expectedLengthBuffer); + // ByteBuf.readBytes(ByteBuffer) throws if fewer than dst.remaining() bytes + // are available, so temporarily narrow the limit to what is actually present. + int toRead = Math.min(ciphertextNettyBuf.readableBytes(), + expectedLengthBuffer.remaining()); + if (toRead > 0) { + int savedLimit = expectedLengthBuffer.limit(); + expectedLengthBuffer.limit(expectedLengthBuffer.position() + toRead); + ciphertextNettyBuf.readBytes(expectedLengthBuffer); + expectedLengthBuffer.limit(savedLimit); + } if (expectedLengthBuffer.hasRemaining()) { // We did not read enough bytes to initialize the expected length. return false; @@ -332,12 +353,22 @@ private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) { return true; } - private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf) + private boolean initializeDecrypter(ByteBuf ciphertextNettyBuf) throws GeneralSecurityException { // Check if the ciphertext header has been read. This contains // the IV and other internal metadata. if (!decrypterInit) { - ciphertextNettyBuf.readBytes(headerBuffer); + // ByteBuf.readBytes(ByteBuffer) throws if fewer than dst.remaining() bytes + // are available. Under TCP fragmentation the header can arrive in multiple + // chunks, so temporarily narrow the limit to what is actually present. 
+ int toRead = Math.min(ciphertextNettyBuf.readableBytes(), + headerBuffer.remaining()); + if (toRead > 0) { + int savedLimit = headerBuffer.limit(); + headerBuffer.limit(headerBuffer.position() + toRead); + ciphertextNettyBuf.readBytes(headerBuffer); + headerBuffer.limit(savedLimit); + } if (headerBuffer.hasRemaining()) { // We did not read enough bytes to initialize the header. return false; @@ -346,7 +377,7 @@ private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf) byte[] lengthAad = Longs.toByteArray(expectedLength); decrypter.init(headerBuffer, lengthAad); decrypterInit = true; - ciphertextRead += aesGcmHkdfStreaming.getHeaderLength(); + ciphertextRead += headerLength; if (expectedLength == ciphertextRead) { // If the expected length is just the header, the ciphertext is 0 length. completed = true; @@ -362,57 +393,75 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) "Unrecognized message type: %s", ciphertextMessage.getClass().getName()); ByteBuf ciphertextNettyBuf = (ByteBuf) ciphertextMessage; - // The format of the output is: + // The format of each message is: // [8 byte length][Internal IV and header][Ciphertext][Auth Tag] + // + // A single channelRead() call may deliver bytes from multiple back-to-back + // GCM messages (common under shuffle load when TCP coalesces writes). The + // outer loop processes as many complete messages as possible from the buffer + // before releasing it, so that bytes belonging to the next message are never + // discarded mid-stream. try { - if (!initalizeExpectedLength(ciphertextNettyBuf)) { - // We have not read enough bytes to initialize the expected length. - return; - } - if (!initalizeDecrypter(ciphertextNettyBuf)) { - // We have not read enough bytes to initialize a header, needed to - // initialize a decrypter. 
- return; - } - int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); - while (nettyBufReadableBytes > 0 && !completed) { - // Read the ciphertext into the local buffer - int readableBytes = Integer.min( - nettyBufReadableBytes, - ciphertextBuffer.remaining()); - int expectedRemaining = (int) (expectedLength - ciphertextRead); - int bytesToRead = Integer.min(readableBytes, expectedRemaining); - // The smallest ciphertext size is 16 bytes for the auth tag - ((Buffer) ciphertextBuffer).limit( - ((Buffer) ciphertextBuffer).position() + bytesToRead); - ciphertextNettyBuf.readBytes(ciphertextBuffer); - ciphertextRead += bytesToRead; - // Check if this is the last segment - if (ciphertextRead == expectedLength) { - completed = true; - } else if (ciphertextRead > expectedLength) { - throw new IllegalStateException("Read more ciphertext than expected."); + while (true) { + if (!initializeExpectedLength(ciphertextNettyBuf)) { + // We have not read enough bytes to initialize the expected length. + break; + } + if (!initializeDecrypter(ciphertextNettyBuf)) { + // We have not read enough bytes to initialize a header, needed to + // initialize a decrypter. 
+ break; + } + int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + while (nettyBufReadableBytes > 0 && !completed) { + // Read the ciphertext into the local buffer + int readableBytes = Math.min( + nettyBufReadableBytes, + ciphertextBuffer.remaining()); + int expectedRemaining = (int) (expectedLength - ciphertextRead); + int bytesToRead = Math.min(readableBytes, expectedRemaining); + // The smallest ciphertext size is 16 bytes for the auth tag + ciphertextBuffer.limit(ciphertextBuffer.position() + bytesToRead); + ciphertextNettyBuf.readBytes(ciphertextBuffer); + ciphertextRead += bytesToRead; + // Check if this is the last segment + if (ciphertextRead == expectedLength) { + completed = true; + } else if (ciphertextRead > expectedLength) { + throw new IllegalStateException("Read more ciphertext than expected."); + } + // If the ciphertext buffer is full, or this is the last segment, + // then decrypt it and fire a read. + if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) { + ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize); + ciphertextBuffer.flip(); + decrypter.decryptSegment( + ciphertextBuffer, + segmentNumber, + completed, + plaintextBuffer); + segmentNumber++; + // Clear the ciphertext buffer because it's been read + ciphertextBuffer.clear(); + plaintextBuffer.flip(); + ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer)); + } else { + // Set the ciphertext buffer up to read the next chunk + ciphertextBuffer.limit(ciphertextBuffer.capacity()); + } + nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + } + if (!completed) { + // Partial message: more bytes needed from the next channelRead() call. + break; } - // If the ciphertext buffer is full, or this is the last segment, - // then decrypt it and fire a read. 
- if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) { - ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize); - ((Buffer) ciphertextBuffer).flip(); - decrypter.decryptSegment( - ciphertextBuffer, - segmentNumber, - completed, - plaintextBuffer); - segmentNumber++; - // Clear the ciphertext buffer because it's been read - ((Buffer) ciphertextBuffer).clear(); - ((Buffer) plaintextBuffer).flip(); - ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer)); - } else { - // Set the ciphertext buffer up to read the next chunk - ((Buffer) ciphertextBuffer).limit(ciphertextBuffer.capacity()); + // Current message is fully decoded. Reset state so the handler can + // decode the next independent GCM message on the same channel. + resetForNextMessage(); + if (ciphertextNettyBuf.readableBytes() == 0) { + break; } - nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + // Remaining bytes may belong to another message; loop to process them. } } finally { ciphertextNettyBuf.release(); diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java index f25277aa1a997..0d720a97fa466 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java @@ -35,6 +35,7 @@ import java.nio.channels.WritableByteChannel; import java.util.Arrays; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.*; @@ -294,6 +295,221 @@ public void testGcmUnalignedDecryption() throws Exception { } } + /** + * Verifies that the same DecryptionHandler instance correctly decodes multiple independent + * GCM-encrypted messages sent over the same channel. 
This is the regression test for the + * bug where DecryptionHandler.completed was never reset, causing every message after the + * first to be silently dropped — which manifested as YARN container launch failures. + */ + @Test + public void testMultipleMessages() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher cipher = server.sessionCipher(); + assert (cipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher; + + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + + // --- First message --- + byte[] data1 = new byte[1024]; + Arrays.fill(data1, (byte) 'A'); + ArgumentCaptor captor1 = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data1), promise); + verify(ctx).write(captor1.capture(), eq(promise)); + ByteBuffer ct1 = ByteBuffer.allocate((int) captor1.getValue().count()); + captor1.getValue().transferTo(new ByteBufferWriteableChannel(ct1), 0); + ct1.flip(); + + ArgumentCaptor plaintextCaptor1 = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, Unpooled.wrappedBuffer(ct1)); + verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor1.capture()); + byte[] decrypted1 = new byte[data1.length]; + int offset = 0; + for (ByteBuf segment : plaintextCaptor1.getAllValues()) { + int len = 
segment.readableBytes(); + segment.readBytes(decrypted1, offset, len); + offset += len; + } + assertEquals(data1.length, offset); + assertArrayEquals(data1, decrypted1); + + // --- Second message (same handler, different content) --- + reset(ctx); + byte[] data2 = new byte[2048]; + Arrays.fill(data2, (byte) 'B'); + ArgumentCaptor captor2 = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data2), promise); + verify(ctx).write(captor2.capture(), eq(promise)); + ByteBuffer ct2 = ByteBuffer.allocate((int) captor2.getValue().count()); + captor2.getValue().transferTo(new ByteBufferWriteableChannel(ct2), 0); + ct2.flip(); + + ArgumentCaptor plaintextCaptor2 = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, Unpooled.wrappedBuffer(ct2)); + verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor2.capture()); + byte[] decrypted2 = new byte[data2.length]; + offset = 0; + for (ByteBuf segment : plaintextCaptor2.getAllValues()) { + int len = segment.readableBytes(); + segment.readBytes(decrypted2, offset, len); + offset += len; + } + assertEquals(data2.length, offset); + assertArrayEquals(data2, decrypted2); + } + } + + /** + * Verifies that multiple GCM-encrypted messages delivered inside a single channelRead() + * call (TCP coalescing) are all decoded correctly. This is the regression test for the + * IllegalStateException("Invalid expected ciphertext length") observed under SparkSQL + * shuffle load: when Netty batches two messages into one ByteBuf, the old code released + * the buffer after the first message, discarding remaining bytes. The next channelRead() + * then read bytes from the middle of the second message as a length header. 
+ */ + @Test + public void testBatchedMessages() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher cipher = server.sessionCipher(); + assert (cipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher; + + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + + byte[] data1 = new byte[1024]; + Arrays.fill(data1, (byte) 'A'); + ArgumentCaptor captor1 = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data1), promise); + verify(ctx).write(captor1.capture(), eq(promise)); + ByteBuffer ct1 = ByteBuffer.allocate((int) captor1.getValue().count()); + captor1.getValue().transferTo(new ByteBufferWriteableChannel(ct1), 0); + ct1.flip(); + + reset(ctx); + byte[] data2 = new byte[2048]; + Arrays.fill(data2, (byte) 'B'); + ArgumentCaptor captor2 = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data2), promise); + verify(ctx).write(captor2.capture(), eq(promise)); + ByteBuffer ct2 = ByteBuffer.allocate((int) captor2.getValue().count()); + captor2.getValue().transferTo(new ByteBufferWriteableChannel(ct2), 0); + ct2.flip(); + + // Simulate TCP coalescing: deliver both ciphertexts in one channelRead() call. 
+ reset(ctx); + ByteBuf batched = Unpooled.wrappedBuffer(ct1, ct2); + ArgumentCaptor plaintextCaptor = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, batched); + verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor.capture()); + + byte[] decrypted = new byte[data1.length + data2.length]; + int offset = 0; + for (ByteBuf segment : plaintextCaptor.getAllValues()) { + int len = segment.readableBytes(); + segment.readBytes(decrypted, offset, len); + offset += len; + } + assertEquals(data1.length + data2.length, offset); + assertArrayEquals(data1, Arrays.copyOfRange(decrypted, 0, data1.length)); + assertArrayEquals(data2, Arrays.copyOfRange(decrypted, data1.length, decrypted.length)); + } + } + + /** + * Verifies that DecryptionHandler correctly handles a GCM message whose framing header + * is split across two channelRead() calls. This is the regression test for the + * IndexOutOfBoundsException in initializeDecrypter observed in benchmarking: when only + * 4 bytes of the 24-byte GCM internal header arrived in one Netty buffer, + * ByteBuf.readBytes(ByteBuffer) threw because it requires all dst.remaining() bytes to + * be available rather than performing a partial fill. 
+ */ + @Test + public void testSplitHeader() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher cipher = server.sessionCipher(); + assert (cipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher; + + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'X'); + ArgumentCaptor captor = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data), promise); + verify(ctx).write(captor.capture(), eq(promise)); + + ByteBuffer ciphertextBuffer = ByteBuffer.allocate((int) captor.getValue().count()); + captor.getValue().transferTo(new ByteBufferWriteableChannel(ciphertextBuffer), 0); + ciphertextBuffer.flip(); + byte[] ciphertext = new byte[ciphertextBuffer.remaining()]; + ciphertextBuffer.get(ciphertext); + + // Split in the middle of the 24-byte GCM internal header: + // chunk1 = [8-byte length field][4 bytes of GCM header] + // chunk2 = [remaining 20 bytes of GCM header][full ciphertext] + int splitPoint = 8 + 4; + ByteBuf chunk1 = Unpooled.wrappedBuffer(ciphertext, 0, splitPoint); + ByteBuf chunk2 = Unpooled.wrappedBuffer( + ciphertext, splitPoint, ciphertext.length - splitPoint); + + decryptionHandler.channelRead(ctx, chunk1); + // Only a partial header was delivered; no 
plaintext should be emitted yet. + verify(ctx, never()).fireChannelRead(any()); + + ArgumentCaptor plaintextCaptor = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, chunk2); + verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor.capture()); + + byte[] decrypted = new byte[data.length]; + int offset = 0; + for (ByteBuf segment : plaintextCaptor.getAllValues()) { + int len = segment.readableBytes(); + segment.readBytes(decrypted, offset, len); + offset += len; + } + assertEquals(data.length, offset); + assertArrayEquals(data, decrypted); + } + } + @Test public void testCorruptGcmEncryptedMessage() throws Exception { TransportConf gcmConf = getConf(2, false); From f7edf22393149259f28955defa23818d8c8d7a97 Mon Sep 17 00:00:00 2001 From: Akira Ajisaka Date: Mon, 27 Apr 2026 15:16:09 +0900 Subject: [PATCH 2/3] [SPARK-56227][CORE][FOLLOWUP] Fix GcmTransportCipher encryptedCount and reduce EventLoop callbacks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two follow-up fixes to SPARK-56227 (`GcmTransportCipher`): **Fix 1 — `encryptedCount` miscalculation for plaintext sizes in (32728, 32752]** `GcmEncryptedMessage.encryptedCount` was computed as: `LENGTH_HEADER_BYTES + expectedCiphertextSize(P)` Tink's `expectedCiphertextSize(P)` internally adds `getCiphertextOffset()` (24 bytes) to P before dividing by `plaintextSegmentSize` to count segments. For P in (32728, 32752], this predicts two ciphertext segments, but `transferTo()` writes the Tink header separately and passes all P bytes to a single `encryptSegment()` call, producing only one segment. The resulting `encryptedCount` was inflated by `TAG_SIZE_IN_BYTES` (16 bytes). After all ciphertext was written, `count() > transferred()`, so subsequent `transferTo()` calls returned 0 and the receiver stalled waiting indefinitely for 16 bytes that were never sent. 
Fix: subtract `getCiphertextOffset()` from the argument to `expectedCiphertextSize()` and add `getHeaderLength()` explicitly. **Fix 2 — Executor heartbeat timeout under concurrent shuffle load** `DecryptionHandler` called `ctx.fireChannelRead()` once per 32 KB ciphertext segment. Decoding a 10 MB shuffle block produced ~320 synchronous EventLoop callbacks within a single `processSelectedKeys()` invocation. With 50+ concurrent shuffle connections, the Netty EventLoop thread was occupied for seconds at a stretch, leaving no time for `runAllTasks()` to execute the executor-driver heartbeat task. Fix: accumulate all decrypted segments zero-copy into a `CompositeByteBuf` and issue a single `ctx.fireChannelRead()` when the complete message is decoded. `maxNumComponents` is set to `Integer.MAX_VALUE` to disable `consolidateIfNeeded()` (which would otherwise copy all data on O(n²) schedule once the default cap of 16 components is exceeded). - `testEncryptedCountBoundary`: encrypts and decrypts messages at 32729, 32740, and 32752 bytes, asserting `count() == transferred()` after encryption and correct round-trip for each. - `testSingleFirePerMessage`: encrypts a 5-segment plaintext and asserts exactly one `fireChannelRead` callback is issued for the full message. - Existing tests updated: `times(2)` → `times(1)` and assertions adjusted for full-message `CompositeByteBuf` in `testGcmEncryptedMessage`, `testGcmEncryptedMessageFileRegion`, and `testGcmUnalignedDecryption`. 
Co-Authored-By: Claude Sonnet 4.6 --- .../network/crypto/GcmTransportCipher.java | 123 ++++++++++--- .../network/crypto/GcmAuthEngineSuite.java | 171 ++++++++++++++++-- 2 files changed, 256 insertions(+), 38 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java index 4ecbd647151a6..ef863c3957574 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java @@ -22,9 +22,13 @@ import com.google.common.primitives.Longs; import com.google.crypto.tink.subtle.*; import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.*; import io.netty.util.ReferenceCounted; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.spark.network.util.AbstractFileRegion; import org.apache.spark.network.util.ByteBufferWriteableChannel; @@ -37,23 +41,19 @@ import java.security.InvalidAlgorithmParameterException; public class GcmTransportCipher implements TransportCipher { + private static final Logger logger = LoggerFactory.getLogger(GcmTransportCipher.class); private static final String HKDF_ALG = "HmacSha256"; private static final int LENGTH_HEADER_BYTES = 8; @VisibleForTesting static final int CIPHERTEXT_BUFFER_SIZE = 32 * 1024; // 32KB private final SecretKeySpec aesKey; + private final AesGcmHkdfStreaming aesGcmHkdfStreaming; - public GcmTransportCipher(SecretKeySpec aesKey) { + public GcmTransportCipher(SecretKeySpec aesKey) throws InvalidAlgorithmParameterException { this.aesKey = aesKey; - } - - AesGcmHkdfStreaming getAesGcmHkdfStreaming() throws InvalidAlgorithmParameterException { - return new AesGcmHkdfStreaming( - aesKey.getEncoded(), - HKDF_ALG, - aesKey.getEncoded().length, - 
CIPHERTEXT_BUFFER_SIZE, - 0); + byte[] keyBytes = aesKey.getEncoded(); + this.aesGcmHkdfStreaming = new AesGcmHkdfStreaming( + keyBytes, HKDF_ALG, keyBytes.length, CIPHERTEXT_BUFFER_SIZE, 0); } /* @@ -82,12 +82,6 @@ public void addToChannel(Channel ch) throws GeneralSecurityException { @VisibleForTesting class EncryptionHandler extends ChannelOutboundHandlerAdapter { - private final AesGcmHkdfStreaming aesGcmHkdfStreaming; - - EncryptionHandler() throws InvalidAlgorithmParameterException { - aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); - } - @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { @@ -121,8 +115,25 @@ static class GcmEncryptedMessage extends AbstractFileRegion { // set its limit to 0 to indicate the first call to transferTo. ((Buffer) this.ciphertextBuffer).limit(0); this.bytesToRead = getReadableBytes(); - this.encryptedCount = - LENGTH_HEADER_BYTES + aesGcmHkdfStreaming.expectedCiphertextSize(bytesToRead); + // Tink's expectedCiphertextSize(P) internally adds getCiphertextOffset() (the 24-byte + // streaming header) to P before computing the segment count: + // fullSegments = (P + getCiphertextOffset()) / plaintextSegmentSize + // This formula counts the header as occupying capacity in the first segment, so the + // effective plaintext capacity of segment 0 is (plaintextSegmentSize - + // getCiphertextOffset()) = 32728 bytes rather than 32752. + // + // However, transferTo() writes the streaming header separately (via headerByteBuffer) + // and passes all P bytes to encryptSegment() calls. For P in (32728, 32752], Tink's + // formula predicts two ciphertext segments but transferTo() produces only one, + // inflating encryptedCount by TAG_SIZE_IN_BYTES (16 bytes). The receiver then waits + // indefinitely for 16 bytes that were never written, causing a shuffle fetch stall. 
+ // + // Fix: subtract getCiphertextOffset() before calling expectedCiphertextSize(), then + // add getHeaderLength() explicitly to account for the separately-written header. + this.encryptedCount = LENGTH_HEADER_BYTES + + aesGcmHkdfStreaming.getHeaderLength() + + aesGcmHkdfStreaming.expectedCiphertextSize( + bytesToRead - aesGcmHkdfStreaming.getCiphertextOffset()); byte[] lengthAad = Longs.toByteArray(encryptedCount); this.encrypter = aesGcmHkdfStreaming.newStreamSegmentEncrypter(lengthAad); this.headerByteBuffer = createHeaderByteBuffer(); @@ -287,7 +298,6 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter { private final ByteBuffer expectedLengthBuffer; private final ByteBuffer headerBuffer; private final ByteBuffer ciphertextBuffer; - private final AesGcmHkdfStreaming aesGcmHkdfStreaming; private StreamSegmentDecrypter decrypter; private final int headerLength; private final int plaintextSegmentSize; @@ -296,9 +306,17 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter { private int segmentNumber = 0; private long expectedLength = -1; private long ciphertextRead = 0; + // Accumulates all decrypted segments for the current GCM message. Each segment is + // appended as a zero-copy component via addComponent(true, segment). A single + // ctx.fireChannelRead() fires when the message is complete, reducing N EventLoop + // callbacks (one per 32 KB segment) to 1. This prevents the EventLoop from being + // monopolised by large messages, which would starve other channels sharing the + // thread (including the executor–driver heartbeat channel) under concurrent shuffle + // load. Null between messages; ownership is transferred to downstream on + // fireChannelRead(). 
+ private CompositeByteBuf plaintextAccumulator = null; DecryptionHandler() throws GeneralSecurityException { - aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); headerLength = aesGcmHkdfStreaming.getHeaderLength(); expectedLengthBuffer = ByteBuffer.allocate(LENGTH_HEADER_BYTES); headerBuffer = ByteBuffer.allocate(headerLength); @@ -316,6 +334,10 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter { * be decrypted with a fresh StreamSegmentDecrypter. */ private void resetForNextMessage() throws GeneralSecurityException { + logger.debug( + "DecryptionHandler: message complete — " + + "expectedLength={}, ciphertextRead={}, segmentCount={}", + expectedLength, ciphertextRead, segmentNumber); expectedLength = -1; expectedLengthBuffer.clear(); headerBuffer.clear(); @@ -325,6 +347,7 @@ private void resetForNextMessage() throws GeneralSecurityException { segmentNumber = 0; ciphertextRead = 0; decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter(); + plaintextAccumulator = null; } private boolean initializeExpectedLength(ByteBuf ciphertextNettyBuf) { @@ -349,6 +372,8 @@ private boolean initializeExpectedLength(ByteBuf ciphertextNettyBuf) { throw new IllegalStateException("Invalid expected ciphertext length."); } ciphertextRead += LENGTH_HEADER_BYTES; + logger.debug( + "DecryptionHandler: new message — expectedLength={}", expectedLength); } return true; } @@ -401,6 +426,11 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) // outer loop processes as many complete messages as possible from the buffer // before releasing it, so that bytes belonging to the next message are never // discarded mid-stream. 
+ final int incomingBytes = ciphertextNettyBuf.readableBytes(); + logger.debug("DecryptionHandler: channelRead — {} incoming bytes, " + + "currentMessageState: expectedLength={}, ciphertextRead={}", + incomingBytes, expectedLength, ciphertextRead); + int totalPlaintextFired = 0; try { while (true) { if (!initializeExpectedLength(ciphertextNettyBuf)) { @@ -444,7 +474,18 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) // Clear the ciphertext buffer because it's been read ciphertextBuffer.clear(); plaintextBuffer.flip(); - ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer)); + totalPlaintextFired += plaintextBuffer.remaining(); + if (plaintextAccumulator == null) { + // Integer.MAX_VALUE disables consolidation entirely. + // CompositeByteBuf.newCompArray() always initialises the + // backing array to min(16, maxNumComponents) regardless of + // this value, so there is no upfront memory cost. + plaintextAccumulator = Unpooled.compositeBuffer(Integer.MAX_VALUE); + } + // Zero-copy append: addComponent(true, ...) increases writerIndex + // so the component is immediately readable from the composite. + plaintextAccumulator.addComponent( + true, Unpooled.wrappedBuffer(plaintextBuffer)); } else { // Set the ciphertext buffer up to read the next chunk ciphertextBuffer.limit(ciphertextBuffer.capacity()); @@ -453,8 +494,24 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) } if (!completed) { // Partial message: more bytes needed from the next channelRead() call. 
+ if (expectedLength < 0) { + logger.debug( + "DecryptionHandler: partial message — length header not yet read"); + } else { + logger.debug( + "DecryptionHandler: partial message — " + + "expectedLength={}, ciphertextRead={}, still need {} bytes", + expectedLength, ciphertextRead, expectedLength - ciphertextRead); + } break; } + // Fire the entire plaintext as a single event so that downstream + // handlers receive one callback per Spark message instead of one per + // 32 KB segment. + if (plaintextAccumulator != null) { + ctx.fireChannelRead(plaintextAccumulator); + plaintextAccumulator = null; // ownership transferred to downstream + } // Current message is fully decoded. Reset state so the handler can // decode the next independent GCM message on the same channel. resetForNextMessage(); @@ -466,6 +523,30 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) } finally { ciphertextNettyBuf.release(); } + logger.debug( + "DecryptionHandler: channelRead done — {} incoming bytes, " + + "{} plaintext bytes fired", + incomingBytes, totalPlaintextFired); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + releaseAccumulator(); + ctx.fireChannelInactive(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + logger.error("Exception in GCM DecryptionHandler", cause); + releaseAccumulator(); + ctx.close(); + } + + private void releaseAccumulator() { + if (plaintextAccumulator != null) { + plaintextAccumulator.release(); + plaintextAccumulator = null; + } } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java index 0d720a97fa466..092489f227384 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java +++ 
b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java @@ -95,13 +95,13 @@ public void testGcmEncryptedMessage() throws Exception { // Capture the decrypted values and verify them ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); decryptionHandler.channelRead(ctx, ciphertext); - verify(ctx, times(2)) + verify(ctx, times(1)) .fireChannelRead(captorPlaintext.capture()); - ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); - assertEquals(plaintextSegmentSize/2, - lastPlaintextSegment.readableBytes()); + ByteBuf plaintext = captorPlaintext.getValue(); + assertEquals(plaintextSegmentSize + (plaintextSegmentSize / 2), + plaintext.readableBytes()); assertEquals('c', - lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); + plaintext.getByte(plaintext.readableBytes() - 10)); } } @@ -224,13 +224,11 @@ public void testGcmEncryptedMessageFileRegion() throws Exception { // Capture the decrypted values and verify them ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); decryptionHandler.channelRead(ctx, ciphertext); - verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); + verify(ctx, times(1)).fireChannelRead(captorPlaintext.capture()); ByteBuf plaintext = captorPlaintext.getValue(); - // We expect this to be the last partial plaintext segment - int expectedLength = totalSize % plaintextSegmentSize; - assertEquals(expectedLength, plaintext.readableBytes()); - // This will be the "remainder" segment that is filled to 'c' - assertEquals('c', plaintext.getByte(0)); + assertEquals(totalSize, plaintext.readableBytes()); + // 'c' starts at the second plaintext segment (offset plaintextSegmentSize) + assertEquals('c', plaintext.getByte(plaintextSegmentSize)); } } @@ -286,12 +284,10 @@ public void testGcmUnalignedDecryption() throws Exception { // Capture the decrypted values and verify them ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); 
decryptionHandler.channelRead(ctx, mockCiphertext); - verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); - ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); - assertEquals(plaintextSegmentSize/2, - lastPlaintextSegment.readableBytes()); - assertEquals('x', - lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); + verify(ctx, times(1)).fireChannelRead(captorPlaintext.capture()); + ByteBuf plaintext = captorPlaintext.getValue(); + assertEquals(plaintextSize, plaintext.readableBytes()); + assertEquals('x', plaintext.getByte(plaintextSize - 10)); } } @@ -510,6 +506,147 @@ public void testSplitHeader() throws Exception { } } + /** + * Regression test for the encryptedCount miscalculation that caused shuffle fetch stalls + * for plaintext sizes in (plaintextSegmentSize - getCiphertextOffset(), plaintextSegmentSize] + * = (32728, 32752]. + * + * Root cause: encryptedCount was computed using {@code expectedCiphertextSize(P)} directly. + * Tink's formula internally adds {@code getCiphertextOffset()} = 24 to P before dividing by + * plaintextSegmentSize to count segments. For P in (32728, 32752] this predicted two + * ciphertext segments, but {@code transferTo()} writes the Tink header separately and passes + * all P bytes to a single {@code encryptSegment()} call, producing only one segment. The + * resulting encryptedCount was inflated by TAG_SIZE_IN_BYTES = 16, so + * {@code count() > transferred()} after all ciphertext was written. Subsequent + * {@code transferTo()} calls returned 0 and the receiver stalled waiting indefinitely for + * 16 bytes that were never sent. 
+ */ + @Test + public void testEncryptedCountBoundary() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher cipher = server.sessionCipher(); + assert (cipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher; + + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + + // plaintextSegmentSize = CIPHERTEXT_BUFFER_SIZE - TAG_SIZE = 32768 - 16 = 32752 + // getCiphertextOffset() = 24 (Tink streaming header, written separately by transferTo) + // Buggy range: P in (32728, 32752] — test lower boundary, midpoint, and upper boundary. 
+ int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + int[] plaintextSizes = { + plaintextSegmentSize - 23, // 32729: one above the lower boundary + plaintextSegmentSize - 12, // 32740: midpoint of the affected range + plaintextSegmentSize // 32752: exactly one full segment (upper boundary) + }; + + for (int plaintextSize : plaintextSizes) { + reset(ctx); + byte[] data = new byte[plaintextSize]; + Arrays.fill(data, (byte) 'Z'); + + ArgumentCaptor encryptedCaptor = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data), promise); + verify(ctx).write(encryptedCaptor.capture(), eq(promise)); + + GcmTransportCipher.GcmEncryptedMessage encrypted = encryptedCaptor.getValue(); + ByteBuffer ciphertextBuf = ByteBuffer.allocate((int) encrypted.count()); + encrypted.transferTo(new ByteBufferWriteableChannel(ciphertextBuf), 0); + + // Before the fix: count() was inflated by 16 bytes for these sizes, so + // transferred() < count() after all plaintext was consumed. The channel stalled + // because subsequent transferTo() calls returned 0 instead of completing. + assertEquals("count() != transferred() for plaintextSize=" + plaintextSize, + encrypted.count(), encrypted.transferred()); + + // Verify the full round-trip also decrypts correctly. + reset(ctx); + ciphertextBuf.flip(); + ArgumentCaptor plaintextCaptor = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, Unpooled.wrappedBuffer(ciphertextBuf)); + verify(ctx, times(1)).fireChannelRead(plaintextCaptor.capture()); + ByteBuf plaintext = plaintextCaptor.getValue(); + assertEquals(plaintextSize, plaintext.readableBytes()); + assertEquals('Z', plaintext.getByte(0)); + assertEquals('Z', plaintext.getByte(plaintextSize - 1)); + } + } + } + + /** + * Regression test for the executor heartbeat timeout caused by per-segment + * {@code ctx.fireChannelRead()} calls in {@code DecryptionHandler}. 
The old code fired one + * EventLoop callback per 32 KB ciphertext segment; decoding a large shuffle block produced + * many synchronous callbacks inside a single {@code processSelectedKeys()} call, + * monopolising the Netty EventLoop and starving the executor-driver heartbeat task. + * + * The fix accumulates all decrypted segments into a {@code CompositeByteBuf} and issues + * exactly one {@code ctx.fireChannelRead()} per Spark message. This test verifies that a + * multi-segment plaintext produces exactly one {@code fireChannelRead} call regardless of + * the segment count. + */ + @Test + public void testSingleFirePerMessage() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher cipher = server.sessionCipher(); + assert (cipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher; + + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + + // Use a 5-segment plaintext so the old per-segment fire (5 calls) is clearly + // distinguishable from the expected single-fire behaviour (1 call). 
+ int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + int plaintextSize = plaintextSegmentSize * 5; + byte[] data = new byte[plaintextSize]; + Arrays.fill(data, (byte) 'M'); + + ArgumentCaptor encryptedCaptor = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data), promise); + verify(ctx).write(encryptedCaptor.capture(), eq(promise)); + + GcmTransportCipher.GcmEncryptedMessage encrypted = encryptedCaptor.getValue(); + ByteBuffer ciphertextBuf = ByteBuffer.allocate((int) encrypted.count()); + encrypted.transferTo(new ByteBufferWriteableChannel(ciphertextBuf), 0); + ciphertextBuf.flip(); + + reset(ctx); + ArgumentCaptor plaintextCaptor = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, Unpooled.wrappedBuffer(ciphertextBuf)); + // The old code fired once per 32 KB segment (5 times for this plaintext size). + // The fix must fire exactly once for the whole message. + verify(ctx, times(1)).fireChannelRead(plaintextCaptor.capture()); + ByteBuf plaintext = plaintextCaptor.getValue(); + assertEquals(plaintextSize, plaintext.readableBytes()); + assertEquals('M', plaintext.getByte(0)); + assertEquals('M', plaintext.getByte(plaintextSize - 1)); + } + } + @Test public void testCorruptGcmEncryptedMessage() throws Exception { TransportConf gcmConf = getConf(2, false); From 3864ea66b5f5d34e2ef84c742776f1e1b892b4e9 Mon Sep 17 00:00:00 2001 From: Akira Ajisaka Date: Fri, 1 May 2026 16:18:54 +0900 Subject: [PATCH 3/3] [SPARK-56227][CORE][FOLLOWUP] Add missing ((Buffer) casts for Java 8 compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When compiled with a JDK 11/17 using -source 8 -target 8 (without --release 8), calls like byteBuffer.flip() emit invokevirtual ByteBuffer.flip():ByteBuffer — a descriptor that does not exist in Java 8's ByteBuffer. 
At Java 8 runtime this throws NoSuchMethodError, causing the External Shuffle Service (which runs in the NodeManager under Java 8) to drop connections. Fix: add ((Buffer) ...) casts in all new/rewritten code sections introduced by the previous commits — resetForNextMessage(), initializeExpectedLength(), initializeDecrypter(), and the channelRead() inner loop — to match the pattern already used throughout the rest of the file. Co-Authored-By: Claude Sonnet 4.6 --- .../network/crypto/GcmTransportCipher.java | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java index ef863c3957574..61d20a0f1156c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java @@ -339,9 +339,9 @@ private void resetForNextMessage() throws GeneralSecurityException { "expectedLength={}, ciphertextRead={}, segmentCount={}", expectedLength, ciphertextRead, segmentNumber); expectedLength = -1; - expectedLengthBuffer.clear(); - headerBuffer.clear(); - ciphertextBuffer.clear(); + ((Buffer) expectedLengthBuffer).clear(); + ((Buffer) headerBuffer).clear(); + ((Buffer) ciphertextBuffer).clear(); decrypterInit = false; completed = false; segmentNumber = 0; @@ -358,9 +358,9 @@ private boolean initializeExpectedLength(ByteBuf ciphertextNettyBuf) { expectedLengthBuffer.remaining()); if (toRead > 0) { int savedLimit = expectedLengthBuffer.limit(); - expectedLengthBuffer.limit(expectedLengthBuffer.position() + toRead); + ((Buffer) expectedLengthBuffer).limit(expectedLengthBuffer.position() + toRead); ciphertextNettyBuf.readBytes(expectedLengthBuffer); - expectedLengthBuffer.limit(savedLimit); + ((Buffer) expectedLengthBuffer).limit(savedLimit); } if 
(expectedLengthBuffer.hasRemaining()) { // We did not read enough bytes to initialize the expected length. @@ -390,9 +390,9 @@ private boolean initializeDecrypter(ByteBuf ciphertextNettyBuf) headerBuffer.remaining()); if (toRead > 0) { int savedLimit = headerBuffer.limit(); - headerBuffer.limit(headerBuffer.position() + toRead); + ((Buffer) headerBuffer).limit(headerBuffer.position() + toRead); ciphertextNettyBuf.readBytes(headerBuffer); - headerBuffer.limit(savedLimit); + ((Buffer) headerBuffer).limit(savedLimit); } if (headerBuffer.hasRemaining()) { // We did not read enough bytes to initialize the header. @@ -451,7 +451,7 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) int expectedRemaining = (int) (expectedLength - ciphertextRead); int bytesToRead = Math.min(readableBytes, expectedRemaining); // The smallest ciphertext size is 16 bytes for the auth tag - ciphertextBuffer.limit(ciphertextBuffer.position() + bytesToRead); + ((Buffer) ciphertextBuffer).limit(ciphertextBuffer.position() + bytesToRead); ciphertextNettyBuf.readBytes(ciphertextBuffer); ciphertextRead += bytesToRead; // Check if this is the last segment @@ -464,7 +464,7 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) // then decrypt it and fire a read. 
if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) { ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize); - ciphertextBuffer.flip(); + ((Buffer) ciphertextBuffer).flip(); decrypter.decryptSegment( ciphertextBuffer, segmentNumber, @@ -472,8 +472,8 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) plaintextBuffer); segmentNumber++; // Clear the ciphertext buffer because it's been read - ciphertextBuffer.clear(); - plaintextBuffer.flip(); + ((Buffer) ciphertextBuffer).clear(); + ((Buffer) plaintextBuffer).flip(); totalPlaintextFired += plaintextBuffer.remaining(); if (plaintextAccumulator == null) { // Integer.MAX_VALUE disables consolidation entirely. @@ -488,7 +488,7 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) true, Unpooled.wrappedBuffer(plaintextBuffer)); } else { // Set the ciphertext buffer up to read the next chunk - ciphertextBuffer.limit(ciphertextBuffer.capacity()); + ((Buffer) ciphertextBuffer).limit(ciphertextBuffer.capacity()); } nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); }