diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java index d3f1bf490d3a..61d20a0f1156 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java @@ -22,9 +22,13 @@ import com.google.common.primitives.Longs; import com.google.crypto.tink.subtle.*; import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.*; import io.netty.util.ReferenceCounted; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import org.apache.spark.network.util.AbstractFileRegion; import org.apache.spark.network.util.ByteBufferWriteableChannel; @@ -37,23 +41,19 @@ import java.security.InvalidAlgorithmParameterException; public class GcmTransportCipher implements TransportCipher { + private static final Logger logger = LoggerFactory.getLogger(GcmTransportCipher.class); private static final String HKDF_ALG = "HmacSha256"; private static final int LENGTH_HEADER_BYTES = 8; @VisibleForTesting static final int CIPHERTEXT_BUFFER_SIZE = 32 * 1024; // 32KB private final SecretKeySpec aesKey; + private final AesGcmHkdfStreaming aesGcmHkdfStreaming; - public GcmTransportCipher(SecretKeySpec aesKey) { + public GcmTransportCipher(SecretKeySpec aesKey) throws InvalidAlgorithmParameterException { this.aesKey = aesKey; - } - - AesGcmHkdfStreaming getAesGcmHkdfStreaming() throws InvalidAlgorithmParameterException { - return new AesGcmHkdfStreaming( - aesKey.getEncoded(), - HKDF_ALG, - aesKey.getEncoded().length, - CIPHERTEXT_BUFFER_SIZE, - 0); + byte[] keyBytes = aesKey.getEncoded(); + this.aesGcmHkdfStreaming = new AesGcmHkdfStreaming( + keyBytes, HKDF_ALG, keyBytes.length, CIPHERTEXT_BUFFER_SIZE, 0); } /* @@ -82,25 +82,10 @@ public void addToChannel(Channel ch) throws GeneralSecurityException { @VisibleForTesting class EncryptionHandler extends ChannelOutboundHandlerAdapter { - private final ByteBuffer plaintextBuffer; - private final ByteBuffer ciphertextBuffer; - private final AesGcmHkdfStreaming aesGcmHkdfStreaming; - - EncryptionHandler() throws InvalidAlgorithmParameterException { - aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); - plaintextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize()); - ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); - } - @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { - GcmEncryptedMessage encryptedMessage = new GcmEncryptedMessage( - aesGcmHkdfStreaming, - msg, - plaintextBuffer, - ciphertextBuffer); - ctx.write(encryptedMessage, promise); + ctx.write(new GcmEncryptedMessage(aesGcmHkdfStreaming, msg), promise); } } @@ -116,22 +101,39 @@ static class GcmEncryptedMessage extends AbstractFileRegion { private final long encryptedCount; GcmEncryptedMessage(AesGcmHkdfStreaming aesGcmHkdfStreaming, - Object plaintextMessage, - ByteBuffer plaintextBuffer, - ByteBuffer ciphertextBuffer) throws GeneralSecurityException { + Object plaintextMessage) throws GeneralSecurityException { Preconditions.checkArgument( plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion, "Unrecognized message type: %s", plaintextMessage.getClass().getName()); this.plaintextMessage = plaintextMessage; - this.plaintextBuffer = plaintextBuffer; - this.ciphertextBuffer = ciphertextBuffer; + this.plaintextBuffer = + ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize()); + this.ciphertextBuffer = + ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); // If the ciphertext buffer cannot be fully written the target, transferTo may // return with it containing some unwritten data. The initial call we'll explicitly // set its limit to 0 to indicate the first call to transferTo. ((Buffer) this.ciphertextBuffer).limit(0); this.bytesToRead = getReadableBytes(); - this.encryptedCount = - LENGTH_HEADER_BYTES + aesGcmHkdfStreaming.expectedCiphertextSize(bytesToRead); + // Tink's expectedCiphertextSize(P) internally adds getCiphertextOffset() (the 24-byte + // streaming header) to P before computing the segment count: + // fullSegments = (P + getCiphertextOffset()) / plaintextSegmentSize + // This formula counts the header as occupying capacity in the first segment, so the + // effective plaintext capacity of segment 0 is (plaintextSegmentSize - + // getCiphertextOffset()) = 32728 bytes rather than 32752. + // + // However, transferTo() writes the streaming header separately (via headerByteBuffer) + // and passes all P bytes to encryptSegment() calls. For P in (32728, 32752], Tink's + // formula predicts two ciphertext segments but transferTo() produces only one, + // inflating encryptedCount by TAG_SIZE_IN_BYTES (16 bytes). The receiver then waits + // indefinitely for 16 bytes that were never written, causing a shuffle fetch stall. + // + // Fix: subtract getCiphertextOffset() before calling expectedCiphertextSize(), then + // add getHeaderLength() explicitly to account for the separately-written header. + this.encryptedCount = LENGTH_HEADER_BYTES + + aesGcmHkdfStreaming.getHeaderLength() + + aesGcmHkdfStreaming.expectedCiphertextSize( + bytesToRead - aesGcmHkdfStreaming.getCiphertextOffset()); byte[] lengthAad = Longs.toByteArray(encryptedCount); this.encrypter = aesGcmHkdfStreaming.newStreamSegmentEncrypter(lengthAad); this.headerByteBuffer = createHeaderByteBuffer(); @@ -296,28 +298,70 @@ class DecryptionHandler extends ChannelInboundHandlerAdapter { private final ByteBuffer expectedLengthBuffer; private final ByteBuffer headerBuffer; private final ByteBuffer ciphertextBuffer; - private final AesGcmHkdfStreaming aesGcmHkdfStreaming; - private final StreamSegmentDecrypter decrypter; + private StreamSegmentDecrypter decrypter; + private final int headerLength; private final int plaintextSegmentSize; private boolean decrypterInit = false; private boolean completed = false; private int segmentNumber = 0; private long expectedLength = -1; private long ciphertextRead = 0; + // Accumulates all decrypted segments for the current GCM message. Each segment is + // appended as a zero-copy component via addComponent(true, segment). A single + // ctx.fireChannelRead() fires when the message is complete, reducing N EventLoop + // callbacks (one per 32 KB segment) to 1. This prevents the EventLoop from being + // monopolised by large messages, which would starve other channels sharing the + // thread (including the executor–driver heartbeat channel) under concurrent shuffle + // load. Null between messages; ownership is transferred to downstream on + // fireChannelRead(). + private CompositeByteBuf plaintextAccumulator = null; DecryptionHandler() throws GeneralSecurityException { - aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); + headerLength = aesGcmHkdfStreaming.getHeaderLength(); expectedLengthBuffer = ByteBuffer.allocate(LENGTH_HEADER_BYTES); - headerBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getHeaderLength()); + headerBuffer = ByteBuffer.allocate(headerLength); ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter(); plaintextSegmentSize = aesGcmHkdfStreaming.getPlaintextSegmentSize(); } - private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) { + /** + * Resets all per-message state so that the next incoming GCM message can be + * decoded through the same channel handler instance. This must be called after + * every successfully completed message because AesGcmHkdfStreaming is a one-shot + * streaming primitive: each encrypted message carries its own random IV and must + * be decrypted with a fresh StreamSegmentDecrypter. + */ + private void resetForNextMessage() throws GeneralSecurityException { + logger.debug( + "DecryptionHandler: message complete — " + + "expectedLength={}, ciphertextRead={}, segmentCount={}", + expectedLength, ciphertextRead, segmentNumber); + expectedLength = -1; + ((Buffer) expectedLengthBuffer).clear(); + ((Buffer) headerBuffer).clear(); + ((Buffer) ciphertextBuffer).clear(); + decrypterInit = false; + completed = false; + segmentNumber = 0; + ciphertextRead = 0; + decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter(); + plaintextAccumulator = null; + } + + private boolean initializeExpectedLength(ByteBuf ciphertextNettyBuf) { if (expectedLength < 0) { - ciphertextNettyBuf.readBytes(expectedLengthBuffer); + // ByteBuf.readBytes(ByteBuffer) throws if fewer than dst.remaining() bytes + // are available, so temporarily narrow the limit to what is actually present. + int toRead = Math.min(ciphertextNettyBuf.readableBytes(), + expectedLengthBuffer.remaining()); + if (toRead > 0) { + int savedLimit = expectedLengthBuffer.limit(); + ((Buffer) expectedLengthBuffer).limit(expectedLengthBuffer.position() + toRead); + ciphertextNettyBuf.readBytes(expectedLengthBuffer); + ((Buffer) expectedLengthBuffer).limit(savedLimit); + } if (expectedLengthBuffer.hasRemaining()) { // We did not read enough bytes to initialize the expected length. return false; @@ -328,16 +372,28 @@ private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) { throw new IllegalStateException("Invalid expected ciphertext length."); } ciphertextRead += LENGTH_HEADER_BYTES; + logger.debug( + "DecryptionHandler: new message — expectedLength={}", expectedLength); } return true; } - private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf) + private boolean initializeDecrypter(ByteBuf ciphertextNettyBuf) throws GeneralSecurityException { // Check if the ciphertext header has been read. This contains // the IV and other internal metadata. if (!decrypterInit) { - ciphertextNettyBuf.readBytes(headerBuffer); + // ByteBuf.readBytes(ByteBuffer) throws if fewer than dst.remaining() bytes + // are available. Under TCP fragmentation the header can arrive in multiple + // chunks, so temporarily narrow the limit to what is actually present. + int toRead = Math.min(ciphertextNettyBuf.readableBytes(), + headerBuffer.remaining()); + if (toRead > 0) { + int savedLimit = headerBuffer.limit(); + ((Buffer) headerBuffer).limit(headerBuffer.position() + toRead); + ciphertextNettyBuf.readBytes(headerBuffer); + ((Buffer) headerBuffer).limit(savedLimit); + } if (headerBuffer.hasRemaining()) { // We did not read enough bytes to initialize the header. return false; @@ -346,7 +402,7 @@ private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf) byte[] lengthAad = Longs.toByteArray(expectedLength); decrypter.init(headerBuffer, lengthAad); decrypterInit = true; - ciphertextRead += aesGcmHkdfStreaming.getHeaderLength(); + ciphertextRead += headerLength; if (expectedLength == ciphertextRead) { // If the expected length is just the header, the ciphertext is 0 length. completed = true; @@ -362,61 +418,135 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) "Unrecognized message type: %s", ciphertextMessage.getClass().getName()); ByteBuf ciphertextNettyBuf = (ByteBuf) ciphertextMessage; - // The format of the output is: + // The format of each message is: // [8 byte length][Internal IV and header][Ciphertext][Auth Tag] + // + // A single channelRead() call may deliver bytes from multiple back-to-back + // GCM messages (common under shuffle load when TCP coalesces writes). The + // outer loop processes as many complete messages as possible from the buffer + // before releasing it, so that bytes belonging to the next message are never + // discarded mid-stream. + final int incomingBytes = ciphertextNettyBuf.readableBytes(); + logger.debug("DecryptionHandler: channelRead — {} incoming bytes, " + + "currentMessageState: expectedLength={}, ciphertextRead={}", + incomingBytes, expectedLength, ciphertextRead); + int totalPlaintextFired = 0; try { - if (!initalizeExpectedLength(ciphertextNettyBuf)) { - // We have not read enough bytes to initialize the expected length. - return; - } - if (!initalizeDecrypter(ciphertextNettyBuf)) { - // We have not read enough bytes to initialize a header, needed to - // initialize a decrypter. - return; - } - int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); - while (nettyBufReadableBytes > 0 && !completed) { - // Read the ciphertext into the local buffer - int readableBytes = Integer.min( - nettyBufReadableBytes, - ciphertextBuffer.remaining()); - int expectedRemaining = (int) (expectedLength - ciphertextRead); - int bytesToRead = Integer.min(readableBytes, expectedRemaining); - // The smallest ciphertext size is 16 bytes for the auth tag - ((Buffer) ciphertextBuffer).limit( - ((Buffer) ciphertextBuffer).position() + bytesToRead); - ciphertextNettyBuf.readBytes(ciphertextBuffer); - ciphertextRead += bytesToRead; - // Check if this is the last segment - if (ciphertextRead == expectedLength) { - completed = true; - } else if (ciphertextRead > expectedLength) { - throw new IllegalStateException("Read more ciphertext than expected."); + while (true) { + if (!initializeExpectedLength(ciphertextNettyBuf)) { + // We have not read enough bytes to initialize the expected length. + break; + } + if (!initializeDecrypter(ciphertextNettyBuf)) { + // We have not read enough bytes to initialize a header, needed to + // initialize a decrypter. + break; + } + int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + while (nettyBufReadableBytes > 0 && !completed) { + // Read the ciphertext into the local buffer + int readableBytes = Math.min( + nettyBufReadableBytes, + ciphertextBuffer.remaining()); + int expectedRemaining = (int) (expectedLength - ciphertextRead); + int bytesToRead = Math.min(readableBytes, expectedRemaining); + // The smallest ciphertext size is 16 bytes for the auth tag + ((Buffer) ciphertextBuffer).limit(ciphertextBuffer.position() + bytesToRead); + ciphertextNettyBuf.readBytes(ciphertextBuffer); + ciphertextRead += bytesToRead; + // Check if this is the last segment + if (ciphertextRead == expectedLength) { + completed = true; + } else if (ciphertextRead > expectedLength) { + throw new IllegalStateException("Read more ciphertext than expected."); + } + // If the ciphertext buffer is full, or this is the last segment, + // then decrypt it and fire a read. + if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) { + ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize); + ((Buffer) ciphertextBuffer).flip(); + decrypter.decryptSegment( + ciphertextBuffer, + segmentNumber, + completed, + plaintextBuffer); + segmentNumber++; + // Clear the ciphertext buffer because it's been read + ((Buffer) ciphertextBuffer).clear(); + ((Buffer) plaintextBuffer).flip(); + totalPlaintextFired += plaintextBuffer.remaining(); + if (plaintextAccumulator == null) { + // Integer.MAX_VALUE disables consolidation entirely. + // CompositeByteBuf.newCompArray() always initialises the + // backing array to min(16, maxNumComponents) regardless of + // this value, so there is no upfront memory cost. + plaintextAccumulator = Unpooled.compositeBuffer(Integer.MAX_VALUE); + } + // Zero-copy append: addComponent(true, ...) increases writerIndex + // so the component is immediately readable from the composite. + plaintextAccumulator.addComponent( + true, Unpooled.wrappedBuffer(plaintextBuffer)); + } else { + // Set the ciphertext buffer up to read the next chunk + ((Buffer) ciphertextBuffer).limit(ciphertextBuffer.capacity()); + } + nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + } + if (!completed) { + // Partial message: more bytes needed from the next channelRead() call. + if (expectedLength < 0) { + logger.debug( + "DecryptionHandler: partial message — length header not yet read"); + } else { + logger.debug( + "DecryptionHandler: partial message — " + + "expectedLength={}, ciphertextRead={}, still need {} bytes", + expectedLength, ciphertextRead, expectedLength - ciphertextRead); + } + break; } - // If the ciphertext buffer is full, or this is the last segment, - // then decrypt it and fire a read. - if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) { - ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize); - ((Buffer) ciphertextBuffer).flip(); - decrypter.decryptSegment( - ciphertextBuffer, - segmentNumber, - completed, - plaintextBuffer); - segmentNumber++; - // Clear the ciphertext buffer because it's been read - ((Buffer) ciphertextBuffer).clear(); - ((Buffer) plaintextBuffer).flip(); - ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer)); - } else { - // Set the ciphertext buffer up to read the next chunk - ((Buffer) ciphertextBuffer).limit(ciphertextBuffer.capacity()); + // Fire the entire plaintext as a single event so that downstream + // handlers receive one callback per Spark message instead of one per + // 32 KB segment. + if (plaintextAccumulator != null) { + ctx.fireChannelRead(plaintextAccumulator); + plaintextAccumulator = null; // ownership transferred to downstream } - nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + // Current message is fully decoded. Reset state so the handler can + // decode the next independent GCM message on the same channel. + resetForNextMessage(); + if (ciphertextNettyBuf.readableBytes() == 0) { + break; + } + // Remaining bytes may belong to another message; loop to process them. } } finally { ciphertextNettyBuf.release(); } + logger.debug( + "DecryptionHandler: channelRead done — {} incoming bytes, " + + "{} plaintext bytes fired", + incomingBytes, totalPlaintextFired); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + releaseAccumulator(); + ctx.fireChannelInactive(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + logger.error("Exception in GCM DecryptionHandler", cause); + releaseAccumulator(); + ctx.close(); + } + + private void releaseAccumulator() { + if (plaintextAccumulator != null) { + plaintextAccumulator.release(); + plaintextAccumulator = null; + } } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java index f25277aa1a99..092489f22738 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java @@ -35,6 +35,7 @@ import java.nio.channels.WritableByteChannel; import java.util.Arrays; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.*; @@ -94,13 +95,13 @@ public void testGcmEncryptedMessage() throws Exception { // Capture the decrypted values and verify them ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); decryptionHandler.channelRead(ctx, ciphertext); - verify(ctx, times(2)) + verify(ctx, times(1)) .fireChannelRead(captorPlaintext.capture()); - ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); - assertEquals(plaintextSegmentSize/2, - lastPlaintextSegment.readableBytes()); + ByteBuf plaintext = captorPlaintext.getValue(); + assertEquals(plaintextSegmentSize + (plaintextSegmentSize / 2), + plaintext.readableBytes()); assertEquals('c', - lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); + plaintext.getByte(plaintext.readableBytes() - 10)); } } @@ -223,13 +224,11 @@ public void testGcmEncryptedMessageFileRegion() throws Exception { // Capture the decrypted values and verify them ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); decryptionHandler.channelRead(ctx, ciphertext); - verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); + verify(ctx, times(1)).fireChannelRead(captorPlaintext.capture()); ByteBuf plaintext = captorPlaintext.getValue(); - // We expect this to be the last partial plaintext segment - int expectedLength = totalSize % plaintextSegmentSize; - assertEquals(expectedLength, plaintext.readableBytes()); - // This will be the "remainder" segment that is filled to 'c' - assertEquals('c', plaintext.getByte(0)); + assertEquals(totalSize, plaintext.readableBytes()); + // 'c' starts at the second plaintext segment (offset plaintextSegmentSize) + assertEquals('c', plaintext.getByte(plaintextSegmentSize)); } } @@ -285,12 +284,366 @@ public void testGcmUnalignedDecryption() throws Exception { // Capture the decrypted values and verify them ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); decryptionHandler.channelRead(ctx, mockCiphertext); - verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); - ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); - assertEquals(plaintextSegmentSize/2, - lastPlaintextSegment.readableBytes()); - assertEquals('x', - lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); + verify(ctx, times(1)).fireChannelRead(captorPlaintext.capture()); + ByteBuf plaintext = captorPlaintext.getValue(); + assertEquals(plaintextSize, plaintext.readableBytes()); + assertEquals('x', plaintext.getByte(plaintextSize - 10)); + } + } + + /** + * Verifies that the same DecryptionHandler instance correctly decodes multiple independent + * GCM-encrypted messages sent over the same channel. This is the regression test for the + * bug where DecryptionHandler.completed was never reset, causing every message after the + * first to be silently dropped — which manifested as YARN container launch failures. + */ + @Test + public void testMultipleMessages() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher cipher = server.sessionCipher(); + assert (cipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher; + + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + + // --- First message --- + byte[] data1 = new byte[1024]; + Arrays.fill(data1, (byte) 'A'); + ArgumentCaptor captor1 = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data1), promise); + verify(ctx).write(captor1.capture(), eq(promise)); + ByteBuffer ct1 = ByteBuffer.allocate((int) captor1.getValue().count()); + captor1.getValue().transferTo(new ByteBufferWriteableChannel(ct1), 0); + ct1.flip(); + + ArgumentCaptor plaintextCaptor1 = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, Unpooled.wrappedBuffer(ct1)); + verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor1.capture()); + byte[] decrypted1 = new byte[data1.length]; + int offset = 0; + for (ByteBuf segment : plaintextCaptor1.getAllValues()) { + int len = segment.readableBytes(); + segment.readBytes(decrypted1, offset, len); + offset += len; + } + assertEquals(data1.length, offset); + assertArrayEquals(data1, decrypted1); + + // --- Second message (same handler, different content) --- + reset(ctx); + byte[] data2 = new byte[2048]; + Arrays.fill(data2, (byte) 'B'); + ArgumentCaptor captor2 = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data2), promise); + verify(ctx).write(captor2.capture(), eq(promise)); + ByteBuffer ct2 = ByteBuffer.allocate((int) captor2.getValue().count()); + captor2.getValue().transferTo(new ByteBufferWriteableChannel(ct2), 0); + ct2.flip(); + + ArgumentCaptor plaintextCaptor2 = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, Unpooled.wrappedBuffer(ct2)); + verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor2.capture()); + byte[] decrypted2 = new byte[data2.length]; + offset = 0; + for (ByteBuf segment : plaintextCaptor2.getAllValues()) { + int len = segment.readableBytes(); + segment.readBytes(decrypted2, offset, len); + offset += len; + } + assertEquals(data2.length, offset); + assertArrayEquals(data2, decrypted2); + } + } + + /** + * Verifies that multiple GCM-encrypted messages delivered inside a single channelRead() + * call (TCP coalescing) are all decoded correctly. This is the regression test for the + * IllegalStateException("Invalid expected ciphertext length") observed under SparkSQL + * shuffle load: when Netty batches two messages into one ByteBuf, the old code released + * the buffer after the first message, discarding remaining bytes. The next channelRead() + * then read bytes from the middle of the second message as a length header. + */ + @Test + public void testBatchedMessages() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher cipher = server.sessionCipher(); + assert (cipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher; + + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + + byte[] data1 = new byte[1024]; + Arrays.fill(data1, (byte) 'A'); + ArgumentCaptor captor1 = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data1), promise); + verify(ctx).write(captor1.capture(), eq(promise)); + ByteBuffer ct1 = ByteBuffer.allocate((int) captor1.getValue().count()); + captor1.getValue().transferTo(new ByteBufferWriteableChannel(ct1), 0); + ct1.flip(); + + reset(ctx); + byte[] data2 = new byte[2048]; + Arrays.fill(data2, (byte) 'B'); + ArgumentCaptor captor2 = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data2), promise); + verify(ctx).write(captor2.capture(), eq(promise)); + ByteBuffer ct2 = ByteBuffer.allocate((int) captor2.getValue().count()); + captor2.getValue().transferTo(new ByteBufferWriteableChannel(ct2), 0); + ct2.flip(); + + // Simulate TCP coalescing: deliver both ciphertexts in one channelRead() call. + reset(ctx); + ByteBuf batched = Unpooled.wrappedBuffer(ct1, ct2); + ArgumentCaptor plaintextCaptor = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, batched); + verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor.capture()); + + byte[] decrypted = new byte[data1.length + data2.length]; + int offset = 0; + for (ByteBuf segment : plaintextCaptor.getAllValues()) { + int len = segment.readableBytes(); + segment.readBytes(decrypted, offset, len); + offset += len; + } + assertEquals(data1.length + data2.length, offset); + assertArrayEquals(data1, Arrays.copyOfRange(decrypted, 0, data1.length)); + assertArrayEquals(data2, Arrays.copyOfRange(decrypted, data1.length, decrypted.length)); + } + } + + /** + * Verifies that DecryptionHandler correctly handles a GCM message whose framing header + * is split across two channelRead() calls. This is the regression test for the + * IndexOutOfBoundsException in initializeDecrypter observed in benchmarking: when only + * 4 bytes of the 24-byte GCM internal header arrived in one Netty buffer, + * ByteBuf.readBytes(ByteBuffer) threw because it requires all dst.remaining() bytes to + * be available rather than performing a partial fill. + */ + @Test + public void testSplitHeader() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher cipher = server.sessionCipher(); + assert (cipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher; + + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + + byte[] data = new byte[1024]; + Arrays.fill(data, (byte) 'X'); + ArgumentCaptor captor = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data), promise); + verify(ctx).write(captor.capture(), eq(promise)); + + ByteBuffer ciphertextBuffer = ByteBuffer.allocate((int) captor.getValue().count()); + captor.getValue().transferTo(new ByteBufferWriteableChannel(ciphertextBuffer), 0); + ciphertextBuffer.flip(); + byte[] ciphertext = new byte[ciphertextBuffer.remaining()]; + ciphertextBuffer.get(ciphertext); + + // Split in the middle of the 24-byte GCM internal header: + // chunk1 = [8-byte length field][4 bytes of GCM header] + // chunk2 = [remaining 20 bytes of GCM header][full ciphertext] + int splitPoint = 8 + 4; + ByteBuf chunk1 = Unpooled.wrappedBuffer(ciphertext, 0, splitPoint); + ByteBuf chunk2 = Unpooled.wrappedBuffer( + ciphertext, splitPoint, ciphertext.length - splitPoint); + + decryptionHandler.channelRead(ctx, chunk1); + // Only a partial header was delivered; no plaintext should be emitted yet. + verify(ctx, never()).fireChannelRead(any()); + + ArgumentCaptor plaintextCaptor = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, chunk2); + verify(ctx, atLeastOnce()).fireChannelRead(plaintextCaptor.capture()); + + byte[] decrypted = new byte[data.length]; + int offset = 0; + for (ByteBuf segment : plaintextCaptor.getAllValues()) { + int len = segment.readableBytes(); + segment.readBytes(decrypted, offset, len); + offset += len; + } + assertEquals(data.length, offset); + assertArrayEquals(data, decrypted); + } + } + + /** + * Regression test for the encryptedCount miscalculation that caused shuffle fetch stalls + * for plaintext sizes in (plaintextSegmentSize - getCiphertextOffset(), plaintextSegmentSize] + * = (32728, 32752]. + * + * Root cause: encryptedCount was computed using {@code expectedCiphertextSize(P)} directly. + * Tink's formula internally adds {@code getCiphertextOffset()} = 24 to P before dividing by + * plaintextSegmentSize to count segments. For P in (32728, 32752] this predicted two + * ciphertext segments, but {@code transferTo()} writes the Tink header separately and passes + * all P bytes to a single {@code encryptSegment()} call, producing only one segment. The + * resulting encryptedCount was inflated by TAG_SIZE_IN_BYTES = 16, so + * {@code count() > transferred()} after all ciphertext was written. Subsequent + * {@code transferTo()} calls returned 0 and the receiver stalled waiting indefinitely for + * 16 bytes that were never sent. + */ + @Test + public void testEncryptedCountBoundary() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher cipher = server.sessionCipher(); + assert (cipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher; + + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + + // plaintextSegmentSize = CIPHERTEXT_BUFFER_SIZE - TAG_SIZE = 32768 - 16 = 32752 + // getCiphertextOffset() = 24 (Tink streaming header, written separately by transferTo) + // Buggy range: P in (32728, 32752] — test lower boundary, midpoint, and upper boundary. + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + int[] plaintextSizes = { + plaintextSegmentSize - 23, // 32729: one above the lower boundary + plaintextSegmentSize - 12, // 32740: midpoint of the affected range + plaintextSegmentSize // 32752: exactly one full segment (upper boundary) + }; + + for (int plaintextSize : plaintextSizes) { + reset(ctx); + byte[] data = new byte[plaintextSize]; + Arrays.fill(data, (byte) 'Z'); + + ArgumentCaptor encryptedCaptor = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data), promise); + verify(ctx).write(encryptedCaptor.capture(), eq(promise)); + + GcmTransportCipher.GcmEncryptedMessage encrypted = encryptedCaptor.getValue(); + ByteBuffer ciphertextBuf = ByteBuffer.allocate((int) encrypted.count()); + encrypted.transferTo(new ByteBufferWriteableChannel(ciphertextBuf), 0); + + // Before the fix: count() was inflated by 16 bytes for these sizes, so + // transferred() < count() after all plaintext was consumed. The channel stalled + // because subsequent transferTo() calls returned 0 instead of completing. + assertEquals("count() != transferred() for plaintextSize=" + plaintextSize, + encrypted.count(), encrypted.transferred()); + + // Verify the full round-trip also decrypts correctly. + reset(ctx); + ciphertextBuf.flip(); + ArgumentCaptor plaintextCaptor = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, Unpooled.wrappedBuffer(ciphertextBuf)); + verify(ctx, times(1)).fireChannelRead(plaintextCaptor.capture()); + ByteBuf plaintext = plaintextCaptor.getValue(); + assertEquals(plaintextSize, plaintext.readableBytes()); + assertEquals('Z', plaintext.getByte(0)); + assertEquals('Z', plaintext.getByte(plaintextSize - 1)); + } + } + } + + /** + * Regression test for the executor heartbeat timeout caused by per-segment + * {@code ctx.fireChannelRead()} calls in {@code DecryptionHandler}. The old code fired one + * EventLoop callback per 32 KB ciphertext segment; decoding a large shuffle block produced + * many synchronous callbacks inside a single {@code processSelectedKeys()} call, + * monopolising the Netty EventLoop and starving the executor-driver heartbeat task. + * + * The fix accumulates all decrypted segments into a {@code CompositeByteBuf} and issues + * exactly one {@code ctx.fireChannelRead()} per Spark message. This test verifies that a + * multi-segment plaintext produces exactly one {@code fireChannelRead} call regardless of + * the segment count. + */ + @Test + public void testSingleFirePerMessage() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher cipher = server.sessionCipher(); + assert (cipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) cipher; + + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + + // Use a 5-segment plaintext so the old per-segment fire (5 calls) is clearly + // distinguishable from the expected single-fire behaviour (1 call). + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + int plaintextSize = plaintextSegmentSize * 5; + byte[] data = new byte[plaintextSize]; + Arrays.fill(data, (byte) 'M'); + + ArgumentCaptor encryptedCaptor = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, Unpooled.wrappedBuffer(data), promise); + verify(ctx).write(encryptedCaptor.capture(), eq(promise)); + + GcmTransportCipher.GcmEncryptedMessage encrypted = encryptedCaptor.getValue(); + ByteBuffer ciphertextBuf = ByteBuffer.allocate((int) encrypted.count()); + encrypted.transferTo(new ByteBufferWriteableChannel(ciphertextBuf), 0); + ciphertextBuf.flip(); + + reset(ctx); + ArgumentCaptor plaintextCaptor = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, Unpooled.wrappedBuffer(ciphertextBuf)); + // The old code fired once per 32 KB segment (5 times for this plaintext size). + // The fix must fire exactly once for the whole message. + verify(ctx, times(1)).fireChannelRead(plaintextCaptor.capture()); + ByteBuf plaintext = plaintextCaptor.getValue(); + assertEquals(plaintextSize, plaintext.readableBytes()); + assertEquals('M', plaintext.getByte(0)); + assertEquals('M', plaintext.getByte(plaintextSize - 1)); } }