From ab783665b6a195586f180c3ecba73b8eaf8ef2b2 Mon Sep 17 00:00:00 2001 From: Steve Weis Date: Fri, 21 Jun 2024 21:07:37 +0800 Subject: [PATCH 1/3] [SPARK-47172][CORE][3.5] Add support for AES-GCM for RPC encryption This change adds AES-GCM as an optional AES cipher mode for RPC encryption. The current default is using AES-CTR without any authentication. That would allow someone on the network to easily modify RPC contents on the wire and impact Spark behavior. See [SPARK-47172](https://issues.apache.org/jira/browse/SPARK-47172) for more details. The current default is using AES-CTR without any authentication. That would allow someone on the network to easily modify RPC contents on the wire and impact Spark behavior. Yes, it adds an additional configuration flag is reflected in the documentation. Existing unit tests are all ensured to pass. New unit tests are written to explicitly test GCM support and to verify that modifying ciphertext content will cause an exception and fail. `build/sbt "network-common/test:testOnly"` `build/sbt "network-common/test:testOnly org.apache.spark.network.crypto.AuthIntegrationSuite"` `build/sbt "network-common/test:testOnly org.apache.spark.network.crypto.AuthEngineSuite"` Nope. Closes #46515 from sweisdb/SPARK-47172. Authored-by: Steve Weis Signed-off-by: Yi Wu --- .../spark/network/crypto/AuthEngine.java | 21 +- .../network/crypto/CtrTransportCipher.java | 381 ++++++++++++++++ .../network/crypto/GcmTransportCipher.java | 420 ++++++++++++++++++ .../spark/network/crypto/TransportCipher.java | 374 +--------------- .../util/ByteBufferWriteableChannel.java | 59 +++ .../spark/network/crypto/AuthEngineSuite.java | 203 ++------- .../network/crypto/AuthIntegrationSuite.java | 79 ++-- .../network/crypto/CtrAuthEngineSuite.java | 178 ++++++++ .../network/crypto/GcmAuthEngineSuite.java | 342 ++++++++++++++ .../network/crypto/TransportCipherSuite.java | 4 +- docs/security.md | 9 + 11 files changed, 1529 insertions(+), 541 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java create mode 100644 common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java create mode 100644 common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java index 14f0c23fd05fc..ee558bce7dab9 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java @@ -45,6 +45,8 @@ class AuthEngine implements Closeable { public static final byte[] INPUT_IV_INFO = "inputIv".getBytes(UTF_8); public static final byte[] OUTPUT_IV_INFO = "outputIv".getBytes(UTF_8); private static final String MAC_ALGORITHM = "HMACSHA256"; + private static final String LEGACY_CIPHER_ALGORITHM = "AES/CTR/NoPadding"; + private static final String CIPHER_ALGORITHM = "AES/GCM/NoPadding"; private static final int AES_GCM_KEY_SIZE_BYTES = 16; private static final byte[] EMPTY_TRANSCRIPT = new byte[0]; private static final int UNSAFE_SKIP_HKDF_VERSION = 1; @@ -227,12 +229,19 @@ private TransportCipher generateTransportCipher( OUTPUT_IV_INFO, // This is the HKDF info field used to differentiate IV values AES_GCM_KEY_SIZE_BYTES); SecretKeySpec sessionKey = new SecretKeySpec(derivedKey, "AES"); - return new TransportCipher( - cryptoConf, - conf.cipherTransformation(), - sessionKey, - isClient ? clientIv : serverIv, // If it's the client, use the client IV first - isClient ? serverIv : clientIv); + if (LEGACY_CIPHER_ALGORITHM.equalsIgnoreCase(conf.cipherTransformation())) { + return new CtrTransportCipher( + cryptoConf, + sessionKey, + isClient ? clientIv : serverIv, // If it's the client, use the client IV first + isClient ? serverIv : clientIv); + } else if (CIPHER_ALGORITHM.equalsIgnoreCase(conf.cipherTransformation())) { + return new GcmTransportCipher(sessionKey); + } else { + throw new IllegalArgumentException( + String.format("Unsupported cipher mode: %s. %s and %s are supported.", + conf.cipherTransformation(), CIPHER_ALGORITHM, LEGACY_CIPHER_ALGORITHM)); + } } private byte[] getTranscript(AuthMessage... encryptedPublicKeys) { diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java new file mode 100644 index 0000000000000..85b893751b39c --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/CtrTransportCipher.java @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; +import java.security.GeneralSecurityException; +import java.util.Properties; +import javax.crypto.spec.SecretKeySpec; +import javax.crypto.spec.IvParameterSpec; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.*; +import org.apache.commons.crypto.stream.CryptoInputStream; +import org.apache.commons.crypto.stream.CryptoOutputStream; + +import org.apache.spark.network.util.AbstractFileRegion; +import org.apache.spark.network.util.ByteArrayReadableChannel; +import org.apache.spark.network.util.ByteArrayWritableChannel; + +/** + * Cipher for encryption and decryption. + */ +public class CtrTransportCipher implements TransportCipher { + @VisibleForTesting + static final String ENCRYPTION_HANDLER_NAME = "CtrTransportEncryption"; + private static final String DECRYPTION_HANDLER_NAME = "CtrTransportDecryption"; + @VisibleForTesting + static final int STREAM_BUFFER_SIZE = 1024 * 32; + + private final Properties conf; + private static final String CIPHER_ALGORITHM = "AES/CTR/NoPadding"; + private final SecretKeySpec key; + private final byte[] inIv; + private final byte[] outIv; + + public CtrTransportCipher( + Properties conf, + SecretKeySpec key, + byte[] inIv, + byte[] outIv) { + this.conf = conf; + this.key = key; + this.inIv = inIv; + this.outIv = outIv; + } + + /* + * This method is for testing purposes only. + */ + @VisibleForTesting + public String getKeyId() throws GeneralSecurityException { + return TransportCipherUtil.getKeyId(key); + } + + @VisibleForTesting + SecretKeySpec getKey() { + return key; + } + + /** The IV for the input channel (i.e. output channel of the remote side). */ + public byte[] getInputIv() { + return inIv; + } + + /** The IV for the output channel (i.e. input channel of the remote side). */ + public byte[] getOutputIv() { + return outIv; + } + + @VisibleForTesting + CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { + return new CryptoOutputStream(CIPHER_ALGORITHM, conf, ch, key, new IvParameterSpec(outIv)); + } + + @VisibleForTesting + CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { + return new CryptoInputStream(CIPHER_ALGORITHM, conf, ch, key, new IvParameterSpec(inIv)); + } + + /** + * Add handlers to channel. + * + * @param ch the channel for adding handlers + * @throws IOException + */ + public void addToChannel(Channel ch) throws IOException { + ch.pipeline() + .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this)) + .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this)); + } + + @VisibleForTesting + static class EncryptionHandler extends ChannelOutboundHandlerAdapter { + private final ByteArrayWritableChannel byteEncChannel; + private final CryptoOutputStream cos; + private final ByteArrayWritableChannel byteRawChannel; + private boolean isCipherValid; + + EncryptionHandler(CtrTransportCipher cipher) throws IOException { + byteEncChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); + cos = cipher.createOutputStream(byteEncChannel); + byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); + isCipherValid = true; + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + ctx.write(createEncryptedMessage(msg), promise); + } + + @VisibleForTesting + EncryptedMessage createEncryptedMessage(Object msg) { + return new EncryptedMessage(this, cos, msg, byteEncChannel, byteRawChannel); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + try { + if (isCipherValid) { + cos.close(); + } + } finally { + super.close(ctx, promise); + } + } + + /** + * SPARK-25535. Workaround for CRYPTO-141. Avoid further interaction with the underlying cipher + * after an error occurs. + */ + void reportError() { + this.isCipherValid = false; + } + + boolean isCipherValid() { + return isCipherValid; + } + } + + private static class DecryptionHandler extends ChannelInboundHandlerAdapter { + private final CryptoInputStream cis; + private final ByteArrayReadableChannel byteChannel; + private boolean isCipherValid; + + DecryptionHandler(CtrTransportCipher cipher) throws IOException { + byteChannel = new ByteArrayReadableChannel(); + cis = cipher.createInputStream(byteChannel); + isCipherValid = true; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { + ByteBuf buffer = (ByteBuf) data; + + try { + if (!isCipherValid) { + throw new IOException("Cipher is in invalid state."); + } + byte[] decryptedData = new byte[buffer.readableBytes()]; + byteChannel.feedData(buffer); + + int offset = 0; + while (offset < decryptedData.length) { + // SPARK-25535: workaround for CRYPTO-141. + try { + offset += cis.read(decryptedData, offset, decryptedData.length - offset); + } catch (InternalError ie) { + isCipherValid = false; + throw ie; + } + } + + ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); + } finally { + buffer.release(); + } + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + // We do the closing of the stream / channel in handlerRemoved(...) as + // this method will be called in all cases: + // + // - when the Channel becomes inactive + // - when the handler is removed from the ChannelPipeline + try { + if (isCipherValid) { + cis.close(); + } + } finally { + super.handlerRemoved(ctx); + } + } + } + + @VisibleForTesting + static class EncryptedMessage extends AbstractFileRegion { + private final boolean isByteBuf; + private final ByteBuf buf; + private final FileRegion region; + private final CryptoOutputStream cos; + private final EncryptionHandler handler; + private final long count; + private long transferred; + + // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has + // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data + // from upper handler, another is used to store encrypted data. + private final ByteArrayWritableChannel byteEncChannel; + private final ByteArrayWritableChannel byteRawChannel; + + private ByteBuffer currentEncrypted; + + EncryptedMessage( + EncryptionHandler handler, + CryptoOutputStream cos, + Object msg, + ByteArrayWritableChannel byteEncChannel, + ByteArrayWritableChannel byteRawChannel) { + Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, + "Unrecognized message type: %s", msg.getClass().getName()); + this.handler = handler; + this.isByteBuf = msg instanceof ByteBuf; + this.buf = isByteBuf ? (ByteBuf) msg : null; + this.region = isByteBuf ? null : (FileRegion) msg; + this.transferred = 0; + this.cos = cos; + this.byteEncChannel = byteEncChannel; + this.byteRawChannel = byteRawChannel; + this.count = isByteBuf ? buf.readableBytes() : region.count(); + } + + @Override + public long count() { + return count; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transferred() { + return transferred; + } + + @Override + public EncryptedMessage touch(Object o) { + super.touch(o); + if (region != null) { + region.touch(o); + } + if (buf != null) { + buf.touch(o); + } + return this; + } + + @Override + public EncryptedMessage retain(int increment) { + super.retain(increment); + if (region != null) { + region.retain(increment); + } + if (buf != null) { + buf.retain(increment); + } + return this; + } + + @Override + public boolean release(int decrement) { + if (region != null) { + region.release(decrement); + } + if (buf != null) { + buf.release(decrement); + } + return super.release(decrement); + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + Preconditions.checkArgument(position == transferred(), "Invalid position."); + + if (transferred == count) { + return 0; + } + + long totalBytesWritten = 0L; + do { + if (currentEncrypted == null) { + encryptMore(); + } + + long remaining = currentEncrypted.remaining(); + if (remaining == 0) { + // Just for safety to avoid endless loop. It usually won't happen, but since the + // underlying `region.transferTo` is allowed to transfer 0 bytes, we should handle it for + // safety. + currentEncrypted = null; + byteEncChannel.reset(); + return totalBytesWritten; + } + + long bytesWritten = target.write(currentEncrypted); + totalBytesWritten += bytesWritten; + transferred += bytesWritten; + if (bytesWritten < remaining) { + // break as the underlying buffer in "target" is full + break; + } + currentEncrypted = null; + byteEncChannel.reset(); + } while (transferred < count); + + return totalBytesWritten; + } + + private void encryptMore() throws IOException { + if (!handler.isCipherValid()) { + throw new IOException("Cipher is in invalid state."); + } + byteRawChannel.reset(); + + if (isByteBuf) { + int copied = byteRawChannel.write(buf.nioBuffer()); + buf.skipBytes(copied); + } else { + region.transferTo(byteRawChannel, region.transferred()); + } + + try { + cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); + cos.flush(); + } catch (InternalError ie) { + handler.reportError(); + throw ie; + } + + currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), + 0, byteEncChannel.length()); + } + + @Override + protected void deallocate() { + byteRawChannel.reset(); + byteEncChannel.reset(); + if (region != null) { + region.release(); + } + if (buf != null) { + buf.release(); + } + } + } + +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java new file mode 100644 index 0000000000000..9599b78007374 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java @@ -0,0 +1,420 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; +import com.google.common.primitives.Longs; +import com.google.crypto.tink.subtle.*; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.*; +import io.netty.util.ReferenceCounted; +import org.apache.spark.network.util.AbstractFileRegion; +import org.apache.spark.network.util.ByteBufferWriteableChannel; + +import javax.crypto.spec.SecretKeySpec; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.security.GeneralSecurityException; +import java.security.InvalidAlgorithmParameterException; + +public class GcmTransportCipher implements TransportCipher { + private static final String HKDF_ALG = "HmacSha256"; + private static final int LENGTH_HEADER_BYTES = 8; + @VisibleForTesting + static final int CIPHERTEXT_BUFFER_SIZE = 32 * 1024; // 32KB + private final SecretKeySpec aesKey; + + public GcmTransportCipher(SecretKeySpec aesKey) { + this.aesKey = aesKey; + } + + AesGcmHkdfStreaming getAesGcmHkdfStreaming() throws InvalidAlgorithmParameterException { + return new AesGcmHkdfStreaming( + aesKey.getEncoded(), + HKDF_ALG, + aesKey.getEncoded().length, + CIPHERTEXT_BUFFER_SIZE, + 0); + } + + /* + * This method is for testing purposes only. + */ + @VisibleForTesting + public String getKeyId() throws GeneralSecurityException { + return TransportCipherUtil.getKeyId(aesKey); + } + + @VisibleForTesting + EncryptionHandler getEncryptionHandler() throws GeneralSecurityException { + return new EncryptionHandler(); + } + + @VisibleForTesting + DecryptionHandler getDecryptionHandler() throws GeneralSecurityException { + return new DecryptionHandler(); + } + + public void addToChannel(Channel ch) throws GeneralSecurityException { + ch.pipeline() + .addFirst("GcmTransportEncryption", getEncryptionHandler()) + .addFirst("GcmTransportDecryption", getDecryptionHandler()); + } + + @VisibleForTesting + class EncryptionHandler extends ChannelOutboundHandlerAdapter { + private final ByteBuffer plaintextBuffer; + private final ByteBuffer ciphertextBuffer; + private final AesGcmHkdfStreaming aesGcmHkdfStreaming; + + EncryptionHandler() throws InvalidAlgorithmParameterException { + aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); + plaintextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getPlaintextSegmentSize()); + ciphertextBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + GcmEncryptedMessage encryptedMessage = new GcmEncryptedMessage( + aesGcmHkdfStreaming, + msg, + plaintextBuffer, + ciphertextBuffer); + ctx.write(encryptedMessage, promise); + } + } + + static class GcmEncryptedMessage extends AbstractFileRegion { + private final Object plaintextMessage; + private final ByteBuffer plaintextBuffer; + private final ByteBuffer ciphertextBuffer; + private final ByteBuffer headerByteBuffer; + private final long bytesToRead; + private long bytesRead = 0; + private final StreamSegmentEncrypter encrypter; + private long transferred = 0; + private final long encryptedCount; + + GcmEncryptedMessage(AesGcmHkdfStreaming aesGcmHkdfStreaming, + Object plaintextMessage, + ByteBuffer plaintextBuffer, + ByteBuffer ciphertextBuffer) throws GeneralSecurityException { + Preconditions.checkArgument( + plaintextMessage instanceof ByteBuf || plaintextMessage instanceof FileRegion, + "Unrecognized message type: %s", plaintextMessage.getClass().getName()); + this.plaintextMessage = plaintextMessage; + this.plaintextBuffer = plaintextBuffer; + this.ciphertextBuffer = ciphertextBuffer; + // If the ciphertext buffer cannot be fully written the target, transferTo may + // return with it containing some unwritten data. The initial call we'll explicitly + // set its limit to 0 to indicate the first call to transferTo. + this.ciphertextBuffer.limit(0); + + this.bytesToRead = getReadableBytes(); + this.encryptedCount = + LENGTH_HEADER_BYTES + aesGcmHkdfStreaming.expectedCiphertextSize(bytesToRead); + byte[] lengthAad = Longs.toByteArray(encryptedCount); + this.encrypter = aesGcmHkdfStreaming.newStreamSegmentEncrypter(lengthAad); + this.headerByteBuffer = createHeaderByteBuffer(); + } + + // The format of the output is: + // [8 byte length][Internal IV and header][Ciphertext][Auth Tag] + private ByteBuffer createHeaderByteBuffer() { + ByteBuffer encrypterHeader = encrypter.getHeader(); + return ByteBuffer + .allocate(encrypterHeader.remaining() + LENGTH_HEADER_BYTES) + .putLong(encryptedCount) + .put(encrypterHeader) + .flip(); + } + + @Override + public long position() { + return 0; + } + + @Override + public long transferred() { + return transferred; + } + + @Override + public long count() { + return encryptedCount; + } + + @Override + public GcmEncryptedMessage touch(Object o) { + super.touch(o); + if (plaintextMessage instanceof ByteBuf) { + ByteBuf byteBuf = (ByteBuf) plaintextMessage; + byteBuf.touch(o); + } else if (plaintextMessage instanceof FileRegion) { + FileRegion fileRegion = (FileRegion) plaintextMessage; + fileRegion.touch(o); + } + return this; + } + + @Override + public GcmEncryptedMessage retain(int increment) { + super.retain(increment); + if (plaintextMessage instanceof ByteBuf) { + ByteBuf byteBuf = (ByteBuf) plaintextMessage; + byteBuf.retain(increment); + } else if (plaintextMessage instanceof FileRegion) { + FileRegion fileRegion = (FileRegion) plaintextMessage; + fileRegion.retain(increment); + } + return this; + } + + @Override + public boolean release(int decrement) { + if (plaintextMessage instanceof ByteBuf) { + ByteBuf byteBuf = (ByteBuf) plaintextMessage; + byteBuf.release(decrement); + } else if (plaintextMessage instanceof FileRegion) { + FileRegion fileRegion = (FileRegion) plaintextMessage; + fileRegion.release(decrement); + } + return super.release(decrement); + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + int transferredThisCall = 0; + // If the header has is not empty, try to write it out to the target. + if (headerByteBuffer.hasRemaining()) { + int written = target.write(headerByteBuffer); + transferredThisCall += written; + this.transferred += written; + if (headerByteBuffer.hasRemaining()) { + return written; + } + } + // If the ciphertext buffer is not empty, try to write it to the target. + if (ciphertextBuffer.hasRemaining()) { + int written = target.write(ciphertextBuffer); + transferredThisCall += written; + this.transferred += written; + if (ciphertextBuffer.hasRemaining()) { + return transferredThisCall; + } + } + while (bytesRead < bytesToRead) { + long readableBytes = getReadableBytes(); + int readLimit = + (int) Math.min(readableBytes, plaintextBuffer.remaining()); + if (plaintextMessage instanceof ByteBuf) { + ByteBuf byteBuf = (ByteBuf) plaintextMessage; + Preconditions.checkState(0 == plaintextBuffer.position()); + plaintextBuffer.limit(readLimit); + byteBuf.readBytes(plaintextBuffer); + Preconditions.checkState(readLimit == plaintextBuffer.position()); + } else if (plaintextMessage instanceof FileRegion) { + FileRegion fileRegion = (FileRegion) plaintextMessage; + ByteBufferWriteableChannel plaintextChannel = + new ByteBufferWriteableChannel(plaintextBuffer); + long plaintextRead = + fileRegion.transferTo(plaintextChannel, fileRegion.transferred()); + if (plaintextRead < readLimit) { + // If we do not read a full plaintext buffer or all the available + // readable bytes, return what was transferred this call. + return transferredThisCall; + } + } + boolean lastSegment = getReadableBytes() == 0; + plaintextBuffer.flip(); + bytesRead += plaintextBuffer.remaining(); + ciphertextBuffer.clear(); + try { + encrypter.encryptSegment(plaintextBuffer, lastSegment, ciphertextBuffer); + } catch (GeneralSecurityException e) { + throw new IllegalStateException("GeneralSecurityException from encrypter", e); + } + plaintextBuffer.clear(); + ciphertextBuffer.flip(); + int written = target.write(ciphertextBuffer); + transferredThisCall += written; + this.transferred += written; + if (ciphertextBuffer.hasRemaining()) { + // In this case, upon calling transferTo again, it will try to write the + // remaining ciphertext buffer in the conditional before this loop. + return transferredThisCall; + } + } + return transferredThisCall; + } + + private long getReadableBytes() { + if (plaintextMessage instanceof ByteBuf) { + ByteBuf byteBuf = (ByteBuf) plaintextMessage; + return byteBuf.readableBytes(); + } else if (plaintextMessage instanceof FileRegion) { + FileRegion fileRegion = (FileRegion) plaintextMessage; + return fileRegion.count() - fileRegion.transferred(); + } else { + throw new IllegalArgumentException("Unsupported message type: " + + plaintextMessage.getClass().getName()); + } + } + + @Override + protected void deallocate() { + if (plaintextMessage instanceof ReferenceCounted) { + ((ReferenceCounted) plaintextMessage).release(); + } + plaintextBuffer.clear(); + ciphertextBuffer.clear(); + } + } + + @VisibleForTesting + class DecryptionHandler extends ChannelInboundHandlerAdapter { + private final ByteBuffer expectedLengthBuffer; + private final ByteBuffer headerBuffer; + private final ByteBuffer ciphertextBuffer; + private final AesGcmHkdfStreaming aesGcmHkdfStreaming; + private final StreamSegmentDecrypter decrypter; + private final int plaintextSegmentSize; + private boolean decrypterInit = false; + private boolean completed = false; + private int segmentNumber = 0; + private long expectedLength = -1; + private long ciphertextRead = 0; + + DecryptionHandler() throws GeneralSecurityException { + aesGcmHkdfStreaming = getAesGcmHkdfStreaming(); + expectedLengthBuffer = ByteBuffer.allocate(LENGTH_HEADER_BYTES); + headerBuffer = ByteBuffer.allocate(aesGcmHkdfStreaming.getHeaderLength()); + ciphertextBuffer = + ByteBuffer.allocate(aesGcmHkdfStreaming.getCiphertextSegmentSize()); + decrypter = aesGcmHkdfStreaming.newStreamSegmentDecrypter(); + plaintextSegmentSize = aesGcmHkdfStreaming.getPlaintextSegmentSize(); + } + + private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) { + if (expectedLength < 0) { + ciphertextNettyBuf.readBytes(expectedLengthBuffer); + if (expectedLengthBuffer.hasRemaining()) { + // We did not read enough bytes to initialize the expected length. + return false; + } + expectedLengthBuffer.flip(); + expectedLength = expectedLengthBuffer.getLong(); + if (expectedLength < 0) { + throw new IllegalStateException("Invalid expected ciphertext length."); + } + ciphertextRead += LENGTH_HEADER_BYTES; + } + return true; + } + + private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf) + throws GeneralSecurityException { + // Check if the ciphertext header has been read. This contains + // the IV and other internal metadata. + if (!decrypterInit) { + ciphertextNettyBuf.readBytes(headerBuffer); + if (headerBuffer.hasRemaining()) { + // We did not read enough bytes to initialize the header. + return false; + } + headerBuffer.flip(); + byte[] lengthAad = Longs.toByteArray(expectedLength); + decrypter.init(headerBuffer, lengthAad); + decrypterInit = true; + ciphertextRead += aesGcmHkdfStreaming.getHeaderLength(); + if (expectedLength == ciphertextRead) { + // If the expected length is just the header, the ciphertext is 0 length. + completed = true; + } + } + return true; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) + throws GeneralSecurityException { + Preconditions.checkArgument(ciphertextMessage instanceof ByteBuf, + "Unrecognized message type: %s", + ciphertextMessage.getClass().getName()); + ByteBuf ciphertextNettyBuf = (ByteBuf) ciphertextMessage; + // The format of the output is: + // [8 byte length][Internal IV and header][Ciphertext][Auth Tag] + try { + if (!initalizeExpectedLength(ciphertextNettyBuf)) { + // We have not read enough bytes to initialize the expected length. + return; + } + if (!initalizeDecrypter(ciphertextNettyBuf)) { + // We have not read enough bytes to initialize a header, needed to + // initialize a decrypter. + return; + } + int nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + while (nettyBufReadableBytes > 0 && !completed) { + // Read the ciphertext into the local buffer + int readableBytes = Integer.min( + nettyBufReadableBytes, + ciphertextBuffer.remaining()); + int expectedRemaining = (int) (expectedLength - ciphertextRead); + int bytesToRead = Integer.min(readableBytes, expectedRemaining); + // The smallest ciphertext size is 16 bytes for the auth tag + ciphertextBuffer.limit(ciphertextBuffer.position() + bytesToRead); + ciphertextNettyBuf.readBytes(ciphertextBuffer); + ciphertextRead += bytesToRead; + // Check if this is the last segment + if (ciphertextRead == expectedLength) { + completed = true; + } else if (ciphertextRead > expectedLength) { + throw new IllegalStateException("Read more ciphertext than expected."); + } + // If the ciphertext buffer is full, or this is the last segment, + // then decrypt it and fire a read. + if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) { + ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize); + ciphertextBuffer.flip(); + decrypter.decryptSegment( + ciphertextBuffer, + segmentNumber, + completed, + plaintextBuffer); + segmentNumber++; + // Clear the ciphertext buffer because it's been read + ciphertextBuffer.clear(); + plaintextBuffer.flip(); + ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer)); + } else { + // Set the ciphertext buffer up to read the next chunk + ciphertextBuffer.limit(ciphertextBuffer.capacity()); + } + nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); + } + } finally { + ciphertextNettyBuf.release(); + } + } + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java index b507f911fe11a..355c552720185 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/TransportCipher.java @@ -17,362 +17,32 @@ package org.apache.spark.network.crypto; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ReadableByteChannel; -import java.nio.channels.WritableByteChannel; -import java.util.Properties; -import javax.crypto.spec.SecretKeySpec; -import javax.crypto.spec.IvParameterSpec; - import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Preconditions; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.*; -import org.apache.commons.crypto.stream.CryptoInputStream; -import org.apache.commons.crypto.stream.CryptoOutputStream; - -import org.apache.spark.network.util.AbstractFileRegion; -import org.apache.spark.network.util.ByteArrayReadableChannel; -import org.apache.spark.network.util.ByteArrayWritableChannel; - -/** - * Cipher for encryption and decryption. - */ -public class TransportCipher { - @VisibleForTesting - static final String ENCRYPTION_HANDLER_NAME = "TransportEncryption"; - private static final String DECRYPTION_HANDLER_NAME = "TransportDecryption"; - @VisibleForTesting - static final int STREAM_BUFFER_SIZE = 1024 * 32; - - private final Properties conf; - private final String cipher; - private final SecretKeySpec key; - private final byte[] inIv; - private final byte[] outIv; - - public TransportCipher( - Properties conf, - String cipher, - SecretKeySpec key, - byte[] inIv, - byte[] outIv) { - this.conf = conf; - this.cipher = cipher; - this.key = key; - this.inIv = inIv; - this.outIv = outIv; - } - - public String getCipherTransformation() { - return cipher; - } - - @VisibleForTesting - SecretKeySpec getKey() { - return key; - } - - /** The IV for the input channel (i.e. output channel of the remote side). */ - public byte[] getInputIv() { - return inIv; - } - - /** The IV for the output channel (i.e. input channel of the remote side). */ - public byte[] getOutputIv() { - return outIv; - } - - @VisibleForTesting - CryptoOutputStream createOutputStream(WritableByteChannel ch) throws IOException { - return new CryptoOutputStream(cipher, conf, ch, key, new IvParameterSpec(outIv)); - } - - @VisibleForTesting - CryptoInputStream createInputStream(ReadableByteChannel ch) throws IOException { - return new CryptoInputStream(cipher, conf, ch, key, new IvParameterSpec(inIv)); - } - - /** - * Add handlers to channel. - * - * @param ch the channel for adding handlers - * @throws IOException - */ - public void addToChannel(Channel ch) throws IOException { - ch.pipeline() - .addFirst(ENCRYPTION_HANDLER_NAME, new EncryptionHandler(this)) - .addFirst(DECRYPTION_HANDLER_NAME, new DecryptionHandler(this)); - } - - @VisibleForTesting - static class EncryptionHandler extends ChannelOutboundHandlerAdapter { - private final ByteArrayWritableChannel byteEncChannel; - private final CryptoOutputStream cos; - private final ByteArrayWritableChannel byteRawChannel; - private boolean isCipherValid; - - EncryptionHandler(TransportCipher cipher) throws IOException { - byteEncChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); - cos = cipher.createOutputStream(byteEncChannel); - byteRawChannel = new ByteArrayWritableChannel(STREAM_BUFFER_SIZE); - isCipherValid = true; - } +import com.google.crypto.tink.subtle.Hex; +import com.google.crypto.tink.subtle.Hkdf; +import io.netty.channel.Channel; - @Override - public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) - throws Exception { - ctx.write(createEncryptedMessage(msg), promise); - } - - @VisibleForTesting - EncryptedMessage createEncryptedMessage(Object msg) { - return new EncryptedMessage(this, cos, msg, byteEncChannel, byteRawChannel); - } +import javax.crypto.spec.SecretKeySpec; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; - @Override - public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { - try { - if (isCipherValid) { - cos.close(); - } - } finally { - super.close(ctx, promise); - } - } +interface TransportCipher { + String getKeyId() throws GeneralSecurityException; + void addToChannel(Channel channel) throws IOException, GeneralSecurityException; +} - /** - * SPARK-25535. Workaround for CRYPTO-141. Avoid further interaction with the underlying cipher - * after an error occurs. +class TransportCipherUtil { + /* + * This method is used for testing to verify key derivation. */ - void reportError() { - this.isCipherValid = false; - } - - boolean isCipherValid() { - return isCipherValid; - } - } - - private static class DecryptionHandler extends ChannelInboundHandlerAdapter { - private final CryptoInputStream cis; - private final ByteArrayReadableChannel byteChannel; - private boolean isCipherValid; - - DecryptionHandler(TransportCipher cipher) throws IOException { - byteChannel = new ByteArrayReadableChannel(); - cis = cipher.createInputStream(byteChannel); - isCipherValid = true; - } - - @Override - public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { - ByteBuf buffer = (ByteBuf) data; - - try { - if (!isCipherValid) { - throw new IOException("Cipher is in invalid state."); - } - byte[] decryptedData = new byte[buffer.readableBytes()]; - byteChannel.feedData(buffer); - - int offset = 0; - while (offset < decryptedData.length) { - // SPARK-25535: workaround for CRYPTO-141. - try { - offset += cis.read(decryptedData, offset, decryptedData.length - offset); - } catch (InternalError ie) { - isCipherValid = false; - throw ie; - } - } - - ctx.fireChannelRead(Unpooled.wrappedBuffer(decryptedData, 0, decryptedData.length)); - } finally { - buffer.release(); - } - } - - @Override - public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { - // We do the closing of the stream / channel in handlerRemoved(...) as - // this method will be called in all cases: - // - // - when the Channel becomes inactive - // - when the handler is removed from the ChannelPipeline - try { - if (isCipherValid) { - cis.close(); - } - } finally { - super.handlerRemoved(ctx); - } - } - } - - @VisibleForTesting - static class EncryptedMessage extends AbstractFileRegion { - private final boolean isByteBuf; - private final ByteBuf buf; - private final FileRegion region; - private final CryptoOutputStream cos; - private final EncryptionHandler handler; - private final long count; - private long transferred; - - // Due to streaming issue CRYPTO-125: https://issues.apache.org/jira/browse/CRYPTO-125, it has - // to utilize two helper ByteArrayWritableChannel for streaming. One is used to receive raw data - // from upper handler, another is used to store encrypted data. - private final ByteArrayWritableChannel byteEncChannel; - private final ByteArrayWritableChannel byteRawChannel; - - private ByteBuffer currentEncrypted; - - EncryptedMessage( - EncryptionHandler handler, - CryptoOutputStream cos, - Object msg, - ByteArrayWritableChannel byteEncChannel, - ByteArrayWritableChannel byteRawChannel) { - Preconditions.checkArgument(msg instanceof ByteBuf || msg instanceof FileRegion, - "Unrecognized message type: %s", msg.getClass().getName()); - this.handler = handler; - this.isByteBuf = msg instanceof ByteBuf; - this.buf = isByteBuf ? (ByteBuf) msg : null; - this.region = isByteBuf ? null : (FileRegion) msg; - this.transferred = 0; - this.cos = cos; - this.byteEncChannel = byteEncChannel; - this.byteRawChannel = byteRawChannel; - this.count = isByteBuf ? buf.readableBytes() : region.count(); - } - - @Override - public long count() { - return count; - } - - @Override - public long position() { - return 0; - } - - @Override - public long transferred() { - return transferred; - } - - @Override - public EncryptedMessage touch(Object o) { - super.touch(o); - if (region != null) { - region.touch(o); - } - if (buf != null) { - buf.touch(o); - } - return this; - } - - @Override - public EncryptedMessage retain(int increment) { - super.retain(increment); - if (region != null) { - region.retain(increment); - } - if (buf != null) { - buf.retain(increment); - } - return this; - } - - @Override - public boolean release(int decrement) { - if (region != null) { - region.release(decrement); - } - if (buf != null) { - buf.release(decrement); - } - return super.release(decrement); - } - - @Override - public long transferTo(WritableByteChannel target, long position) throws IOException { - Preconditions.checkArgument(position == transferred(), "Invalid position."); - - if (transferred == count) { - return 0; - } - - long totalBytesWritten = 0L; - do { - if (currentEncrypted == null) { - encryptMore(); - } - - long remaining = currentEncrypted.remaining(); - if (remaining == 0) { - // Just for safety to avoid endless loop. It usually won't happen, but since the - // underlying `region.transferTo` is allowed to transfer 0 bytes, we should handle it for - // safety. - currentEncrypted = null; - byteEncChannel.reset(); - return totalBytesWritten; - } - - long bytesWritten = target.write(currentEncrypted); - totalBytesWritten += bytesWritten; - transferred += bytesWritten; - if (bytesWritten < remaining) { - // break as the underlying buffer in "target" is full - break; - } - currentEncrypted = null; - byteEncChannel.reset(); - } while (transferred < count); - - return totalBytesWritten; - } - - private void encryptMore() throws IOException { - if (!handler.isCipherValid()) { - throw new IOException("Cipher is in invalid state."); - } - byteRawChannel.reset(); - - if (isByteBuf) { - int copied = byteRawChannel.write(buf.nioBuffer()); - buf.skipBytes(copied); - } else { - region.transferTo(byteRawChannel, region.transferred()); - } - - try { - cos.write(byteRawChannel.getData(), 0, byteRawChannel.length()); - cos.flush(); - } catch (InternalError ie) { - handler.reportError(); - throw ie; - } - - currentEncrypted = ByteBuffer.wrap(byteEncChannel.getData(), - 0, byteEncChannel.length()); - } - - @Override - protected void deallocate() { - byteRawChannel.reset(); - byteEncChannel.reset(); - if (region != null) { - region.release(); - } - if (buf != null) { - buf.release(); - } + @VisibleForTesting + static String getKeyId(SecretKeySpec key) throws GeneralSecurityException { + byte[] keyIdBytes = Hkdf.computeHkdf("HmacSha256", + key.getEncoded(), + null, + "keyID".getBytes(StandardCharsets.UTF_8), + 32); + return Hex.encode(keyIdBytes); } - } - } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java new file mode 100644 index 0000000000000..b20240cfcaa6d --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.WritableByteChannel; + +public class ByteBufferWriteableChannel implements WritableByteChannel { + private final ByteBuffer destination; + private boolean open; + + public ByteBufferWriteableChannel(ByteBuffer destination) { + this.destination = destination; + this.open = true; + } + + @Override + public int write(ByteBuffer src) throws IOException { + if (!isOpen()) { + throw new ClosedChannelException(); + } + int bytesToWrite = Math.min(src.remaining(), destination.remaining()); + // Destination buffer is full + if (bytesToWrite == 0) { + return 0; + } + ByteBuffer temp = src.slice().limit(bytesToWrite); + destination.put(temp); + src.position(src.position() + bytesToWrite); + return bytesToWrite; + } + + @Override + public boolean isOpen() { + return open; + } + + @Override + public void close() { + open = false; + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java index 971e3ef2ff98c..ad737e5332dd4 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthEngineSuite.java @@ -18,75 +18,76 @@ package org.apache.spark.network.crypto; import java.nio.ByteBuffer; -import java.nio.channels.WritableByteChannel; import java.security.GeneralSecurityException; -import java.util.Collections; -import java.util.Random; +import java.util.Map; +import com.google.common.collect.ImmutableMap; import com.google.crypto.tink.subtle.Hex; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.channel.FileRegion; -import org.apache.spark.network.util.ByteArrayWritableChannel; -import org.apache.spark.network.util.ConfigProvider; -import org.apache.spark.network.util.MapConfigProvider; -import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.util.*; + import static org.junit.Assert.*; -import org.junit.BeforeClass; import org.junit.Test; -import static org.mockito.Mockito.*; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -public class AuthEngineSuite { - - private static final String clientPrivate = - "efe6b68b3fce92158e3637f6ef9d937e75558928dd4b401de04b43d300a73186"; - private static final String clientChallengeHex = - "fb00000005617070496400000010890b6e960f48e998777267a7e4e623220000003c48ad7dc7ec9466da9" + - "3bda9f11488dc9404050e02c661d87d67c782444944c6e369b27e0a416c30845a2d9e64271511ca98b41d" + - "65f8c426e18ff380f6"; - private static final String serverResponseHex = - "fb00000005617070496400000010708451c9dd2792c97c1ca66e6df449ef0000003c64fe899ecdaf458d4" + - "e25e9d5c5a380b8e6d1a184692fac065ed84f8592c18e9629f9c636809dca2ffc041f20346eb53db78738" + - "08ecad08b46b5ee3ff"; - private static final String derivedKey = "2d6e7a9048c8265c33a8f3747bfcc84c"; +abstract class AuthEngineSuite { + static final String clientPrivate = + "efe6b68b3fce92158e3637f6ef9d937e75558928dd4b401de04b43d300a73186"; + static final String clientChallengeHex = + "fb00000005617070496400000010890b6e960f48e998777267a7e4e623220000003c48ad7dc7ec9466da9" + + "3bda9f11488dc9404050e02c661d87d67c782444944c6e369b27e0a416c30845a2d9e64271511ca98b41d" + + "65f8c426e18ff380f6"; + static final String serverResponseHex = + "fb00000005617070496400000010708451c9dd2792c97c1ca66e6df449ef0000003c64fe899ecdaf458d4" + + "e25e9d5c5a380b8e6d1a184692fac065ed84f8592c18e9629f9c636809dca2ffc041f20346eb53db78738" + + "08ecad08b46b5ee3ff"; + static final String derivedKeyId = + "de04fd52d71040ed9d260579dacfdf4f5695f991ce8ddb1dde05a7335880906e"; // This key would have been derived for version 1.0 protocol that did not run a final HKDF round. - private static final String unsafeDerivedKey = - "31963f15a320d5c90333f7ecf5cf3a31c7eaf151de07fef8494663a9f47cfd31"; - private static final String inputIv = "fc6a5dc8b90a9dad8f54f08b51a59ed2"; - private static final String outputIv = "a72709baf00785cad6329ce09f631f71"; - private static TransportConf conf; - - @BeforeClass - public static void setUp() { - ConfigProvider v2Provider = new MapConfigProvider(Collections.singletonMap( - "spark.network.crypto.authEngineVersion", "2")); - conf = new TransportConf("rpc", v2Provider); + static final String unsafeDerivedKey = + "31963f15a320d5c90333f7ecf5cf3a31c7eaf151de07fef8494663a9f47cfd31"; + static TransportConf conf; + + static TransportConf getConf(int authEngineVerison, boolean useCtr) { + String authEngineVersion = (authEngineVerison == 1) ? "1" : "2"; + String mode = useCtr ? "AES/CTR/NoPadding" : "AES/GCM/NoPadding"; + Map confMap = ImmutableMap.of( + "spark.network.crypto.enabled", "true", + "spark.network.crypto.authEngineVersion", authEngineVersion, + "spark.network.crypto.cipher", mode + ); + ConfigProvider v2Provider = new MapConfigProvider(confMap); + return new TransportConf("rpc", v2Provider); } @Test public void testAuthEngine() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); AuthMessage serverResponse = server.response(clientChallenge); client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher serverCipher = server.sessionCipher(); TransportCipher clientCipher = client.sessionCipher(); + assertEquals(clientCipher.getKeyId(), serverCipher.getKeyId()); + } + } - assertArrayEquals(serverCipher.getInputIv(), clientCipher.getOutputIv()); - assertArrayEquals(serverCipher.getOutputIv(), clientCipher.getInputIv()); - assertEquals(serverCipher.getKey(), clientCipher.getKey()); + @Test + public void testFixedChallengeResponse() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf)) { + byte[] clientPrivateKey = Hex.decode(clientPrivate); + client.setClientPrivateKey(clientPrivateKey); + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + AuthMessage serverResponse = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); + // Verify that the client will accept an old transcript. + client.deriveSessionCipher(clientChallenge, serverResponse); + assertEquals(client.sessionCipher().getKeyId(), derivedKeyId); } } @Test public void testCorruptChallengeAppId() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -98,7 +99,6 @@ public void testCorruptChallengeAppId() throws Exception { @Test public void testCorruptChallengeSalt() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -109,7 +109,6 @@ public void testCorruptChallengeSalt() throws Exception { @Test public void testCorruptChallengeCiphertext() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -120,7 +119,6 @@ public void testCorruptChallengeCiphertext() throws Exception { @Test public void testCorruptResponseAppId() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -134,20 +132,18 @@ public void testCorruptResponseAppId() throws Exception { @Test public void testCorruptResponseSalt() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); AuthMessage serverResponse = server.response(clientChallenge); serverResponse.salt[0] ^= 1; assertThrows(GeneralSecurityException.class, - () -> client.deriveSessionCipher(clientChallenge, serverResponse)); + () -> client.deriveSessionCipher(clientChallenge, serverResponse)); } } @Test public void testCorruptServerCiphertext() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); AuthEngine server = new AuthEngine("appId", "secret", conf)) { AuthMessage clientChallenge = client.challenge(); @@ -169,45 +165,6 @@ public void testFixedChallenge() throws Exception { } } - @Test - public void testFixedChallengeResponse() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf)) { - byte[] clientPrivateKey = Hex.decode(clientPrivate); - client.setClientPrivateKey(clientPrivateKey); - AuthMessage clientChallenge = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); - AuthMessage serverResponse = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); - // Verify that the client will accept an old transcript. - client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher clientCipher = client.sessionCipher(); - assertEquals(Hex.encode(clientCipher.getKey().getEncoded()), derivedKey); - assertEquals(Hex.encode(clientCipher.getInputIv()), inputIv); - assertEquals(Hex.encode(clientCipher.getOutputIv()), outputIv); - } - } - - @Test - public void testFixedChallengeResponseUnsafeVersion() throws Exception { - ConfigProvider v1Provider = new MapConfigProvider(Collections.singletonMap( - "spark.network.crypto.authEngineVersion", "1")); - TransportConf v1Conf = new TransportConf("rpc", v1Provider); - try (AuthEngine client = new AuthEngine("appId", "secret", v1Conf)) { - byte[] clientPrivateKey = Hex.decode(clientPrivate); - client.setClientPrivateKey(clientPrivateKey); - AuthMessage clientChallenge = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); - AuthMessage serverResponse = - AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); - // Verify that the client will accept an old transcript. - client.deriveSessionCipher(clientChallenge, serverResponse); - TransportCipher clientCipher = client.sessionCipher(); - assertEquals(Hex.encode(clientCipher.getKey().getEncoded()), unsafeDerivedKey); - assertEquals(Hex.encode(clientCipher.getInputIv()), inputIv); - assertEquals(Hex.encode(clientCipher.getOutputIv()), outputIv); - } - } - @Test public void testMismatchedSecret() throws Exception { try (AuthEngine client = new AuthEngine("appId", "secret", conf); @@ -216,70 +173,4 @@ public void testMismatchedSecret() throws Exception { assertThrows(GeneralSecurityException.class, () -> server.response(clientChallenge)); } } - - @Test - public void testEncryptedMessage() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); - AuthEngine server = new AuthEngine("appId", "secret", conf)) { - AuthMessage clientChallenge = client.challenge(); - AuthMessage serverResponse = server.response(clientChallenge); - client.deriveSessionCipher(clientChallenge, serverResponse); - - TransportCipher cipher = server.sessionCipher(); - TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher); - - byte[] data = new byte[TransportCipher.STREAM_BUFFER_SIZE + 1]; - new Random().nextBytes(data); - ByteBuf buf = Unpooled.wrappedBuffer(data); - - ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length); - TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(buf); - while (emsg.transferred() < emsg.count()) { - emsg.transferTo(channel, emsg.transferred()); - } - assertEquals(data.length, channel.length()); - } - } - - @Test - public void testEncryptedMessageWhenTransferringZeroBytes() throws Exception { - try (AuthEngine client = new AuthEngine("appId", "secret", conf); - AuthEngine server = new AuthEngine("appId", "secret", conf)) { - AuthMessage clientChallenge = client.challenge(); - AuthMessage serverResponse = server.response(clientChallenge); - client.deriveSessionCipher(clientChallenge, serverResponse); - - TransportCipher cipher = server.sessionCipher(); - TransportCipher.EncryptionHandler handler = new TransportCipher.EncryptionHandler(cipher); - - int testDataLength = 4; - FileRegion region = mock(FileRegion.class); - when(region.count()).thenReturn((long) testDataLength); - // Make `region.transferTo` do nothing in first call and transfer 4 bytes in the second one. - when(region.transferTo(any(), anyLong())).thenAnswer(new Answer() { - - private boolean firstTime = true; - - @Override - public Long answer(InvocationOnMock invocationOnMock) throws Throwable { - if (firstTime) { - firstTime = false; - return 0L; - } else { - WritableByteChannel channel = invocationOnMock.getArgument(0); - channel.write(ByteBuffer.wrap(new byte[testDataLength])); - return (long) testDataLength; - } - } - }); - - TransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(region); - ByteArrayWritableChannel channel = new ByteArrayWritableChannel(testDataLength); - // "transferTo" should act correctly when the underlying FileRegion transfers 0 bytes. - assertEquals(0L, emsg.transferTo(channel, emsg.transferred())); - assertEquals(testDataLength, emsg.transferTo(channel, emsg.transferred())); - assertEquals(emsg.transferred(), emsg.count()); - assertEquals(4, channel.length()); - } - } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java index 4a5b426b1158a..ad8bbdb4c2655 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java @@ -49,7 +49,7 @@ public class AuthIntegrationSuite { private AuthTestCtx ctx; @After - public void cleanUp() throws Exception { + public void cleanUp() { if (ctx != null) { ctx.close(); } @@ -57,8 +57,8 @@ public void cleanUp() throws Exception { } @Test - public void testNewAuth() throws Exception { - ctx = new AuthTestCtx(); + public void testNewCtrAuth() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/CTR/NoPadding"); ctx.createServer("secret"); ctx.createClient("secret"); @@ -68,8 +68,28 @@ public void testNewAuth() throws Exception { } @Test - public void testAuthFailure() throws Exception { - ctx = new AuthTestCtx(); + public void testNewGcmAuth() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/GCM/NoPadding"); + ctx.createServer("secret"); + ctx.createClient("secret"); + ByteBuffer reply = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000); + assertEquals("Pong", JavaUtils.bytesToString(reply)); + assertNull(ctx.authRpcHandler.saslHandler); + } + + @Test + public void testCtrAuthFailure() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/CTR/NoPadding"); + ctx.createServer("server"); + + assertThrows(Exception.class, () -> ctx.createClient("client")); + assertFalse(ctx.authRpcHandler.isAuthenticated()); + assertFalse(ctx.serverChannel.isActive()); + } + + @Test + public void testGcmAuthFailure() throws Exception { + ctx = new AuthTestCtx(new DummyRpcHandler(), "AES/GCM/NoPadding"); ctx.createServer("server"); assertThrows(Exception.class, () -> ctx.createClient("client")); @@ -100,7 +120,7 @@ public void testSaslClientFallback() throws Exception { } @Test - public void testAuthReplay() throws Exception { + public void testCtrAuthReplay() throws Exception { // This test covers the case where an attacker replays a challenge message sniffed from the // network, but doesn't know the actual secret. The server should close the connection as // soon as a message is sent after authentication is performed. This is emulated by removing @@ -110,16 +130,16 @@ public void testAuthReplay() throws Exception { ctx.createClient("secret"); assertNotNull(ctx.client.getChannel().pipeline() - .remove(TransportCipher.ENCRYPTION_HANDLER_NAME)); + .remove(CtrTransportCipher.ENCRYPTION_HANDLER_NAME)); assertThrows(Exception.class, () -> ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000)); assertTrue(ctx.authRpcHandler.isAuthenticated()); } @Test - public void testLargeMessageEncryption() throws Exception { + public void testLargeCtrMessageEncryption() throws Exception { // Use a big length to create a message that cannot be put into the encryption buffer completely - final int testErrorMessageLength = TransportCipher.STREAM_BUFFER_SIZE; + final int testErrorMessageLength = CtrTransportCipher.STREAM_BUFFER_SIZE; ctx = new AuthTestCtx(new RpcHandler() { @Override public void receive( @@ -157,6 +177,23 @@ public void testValidMergedBlockMetaReqHandler() throws Exception { assertNotNull(ctx.authRpcHandler.getMergedBlockMetaReqHandler()); } + private static class DummyRpcHandler extends RpcHandler { + @Override + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) { + String messageString = JavaUtils.bytesToString(message); + assertEquals("Ping", messageString); + callback.onSuccess(JavaUtils.stringToBytes("Pong")); + } + + @Override + public StreamManager getStreamManager() { + return null; + } + } + private static class AuthTestCtx { private final String appId = "testAppId"; @@ -169,25 +206,17 @@ private static class AuthTestCtx { volatile AuthRpcHandler authRpcHandler; AuthTestCtx() throws Exception { - this(new RpcHandler() { - @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) { - assertEquals("Ping", JavaUtils.bytesToString(message)); - callback.onSuccess(JavaUtils.stringToBytes("Pong")); - } - - @Override - public StreamManager getStreamManager() { - return null; - } - }); + this(new DummyRpcHandler()); } AuthTestCtx(RpcHandler rpcHandler) throws Exception { - Map testConf = ImmutableMap.of("spark.network.crypto.enabled", "true"); + this(rpcHandler, "AES/CTR/NoPadding"); + } + + AuthTestCtx(RpcHandler rpcHandler, String mode) throws Exception { + Map testConf = ImmutableMap.of( + "spark.network.crypto.enabled", "true", + "spark.network.crypto.cipher", mode); this.conf = new TransportConf("rpc", new MapConfigProvider(testConf)); this.ctx = new TransportContext(conf, rpcHandler); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java new file mode 100644 index 0000000000000..dcec2f17be532 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/CtrAuthEngineSuite.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import com.google.crypto.tink.subtle.Hex; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.TransportConf; +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.Random; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.*; + +public class CtrAuthEngineSuite extends AuthEngineSuite { + private static final String inputIv = "fc6a5dc8b90a9dad8f54f08b51a59ed2"; + private static final String outputIv = "a72709baf00785cad6329ce09f631f71"; + + @Before + public void setUp() { + conf = getConf(2, true); + } + + @Test + public void testAuthEngine() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + + TransportCipher serverCipher = server.sessionCipher(); + TransportCipher clientCipher = client.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + assert(serverCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrClient = (CtrTransportCipher) clientCipher; + CtrTransportCipher ctrServer = (CtrTransportCipher) serverCipher; + assertArrayEquals(ctrServer.getInputIv(), ctrClient.getOutputIv()); + assertArrayEquals(ctrServer.getOutputIv(), ctrClient.getInputIv()); + assertEquals(ctrServer.getKey(), ctrClient.getKey()); + } + } + + @Test + public void testCtrFixedChallengeIvResponse() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf)) { + byte[] clientPrivateKey = Hex.decode(clientPrivate); + client.setClientPrivateKey(clientPrivateKey); + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + AuthMessage serverResponse = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); + // Verify that the client will accept an old transcript. + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = client.sessionCipher(); + assertEquals(clientCipher.getKeyId(), derivedKeyId); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + assertEquals(Hex.encode(ctrTransportCipher.getInputIv()), inputIv); + assertEquals(Hex.encode(ctrTransportCipher.getOutputIv()), outputIv); + } + } + + @Test + public void testFixedChallengeResponseUnsafeVersion() throws Exception { + TransportConf v1Conf = getConf(1, true); + try (AuthEngine client = new AuthEngine("appId", "secret", v1Conf)) { + byte[] clientPrivateKey = Hex.decode(clientPrivate); + client.setClientPrivateKey(clientPrivateKey); + AuthMessage clientChallenge = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(clientChallengeHex))); + AuthMessage serverResponse = + AuthMessage.decodeMessage(ByteBuffer.wrap(Hex.decode(serverResponseHex))); + // Verify that the client will accept an old transcript. + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = client.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + assertEquals(Hex.encode(ctrTransportCipher.getKey().getEncoded()), unsafeDerivedKey); + assertEquals(Hex.encode(ctrTransportCipher.getInputIv()), inputIv); + assertEquals(Hex.encode(ctrTransportCipher.getOutputIv()), outputIv); + } + } + + @Test + public void testCtrEncryptedMessage() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + + TransportCipher clientCipher = server.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + CtrTransportCipher.EncryptionHandler handler = + new CtrTransportCipher.EncryptionHandler(ctrTransportCipher); + + byte[] data = new byte[CtrTransportCipher.STREAM_BUFFER_SIZE + 1]; + new Random().nextBytes(data); + ByteBuf buf = Unpooled.wrappedBuffer(data); + + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(data.length); + CtrTransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(buf); + while (emsg.transferred() < emsg.count()) { + emsg.transferTo(channel, emsg.transferred()); + } + assertEquals(data.length, channel.length()); + } + } + + @Test + public void testCtrEncryptedMessageWhenTransferringZeroBytes() throws Exception { + try (AuthEngine client = new AuthEngine("appId", "secret", conf); + AuthEngine server = new AuthEngine("appId", "secret", conf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + assert(clientCipher instanceof CtrTransportCipher); + CtrTransportCipher ctrTransportCipher = (CtrTransportCipher) clientCipher; + CtrTransportCipher.EncryptionHandler handler = + new CtrTransportCipher.EncryptionHandler(ctrTransportCipher); + int testDataLength = 4; + FileRegion region = mock(FileRegion.class); + when(region.count()).thenReturn((long) testDataLength); + // Make `region.transferTo` do nothing in first call and transfer 4 bytes in the second one. + when(region.transferTo(any(), anyLong())).thenAnswer(new Answer() { + + private boolean firstTime = true; + + @Override + public Long answer(InvocationOnMock invocationOnMock) throws Throwable { + if (firstTime) { + firstTime = false; + return 0L; + } else { + WritableByteChannel channel = invocationOnMock.getArgument(0); + channel.write(ByteBuffer.wrap(new byte[testDataLength])); + return (long) testDataLength; + } + } + }); + + CtrTransportCipher.EncryptedMessage emsg = handler.createEncryptedMessage(region); + ByteArrayWritableChannel channel = new ByteArrayWritableChannel(testDataLength); + // "transferTo" should act correctly when the underlying FileRegion transfers 0 bytes. + assertEquals(0L, emsg.transferTo(channel, emsg.transferred())); + assertEquals(testDataLength, emsg.transferTo(channel, emsg.transferred())); + assertEquals(emsg.transferred(), emsg.count()); + assertEquals(4, channel.length()); + } + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java new file mode 100644 index 0000000000000..19e9eb41a1ef8 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.crypto; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import org.apache.spark.network.util.AbstractFileRegion; +import org.apache.spark.network.util.ByteBufferWriteableChannel; +import org.apache.spark.network.util.TransportConf; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import javax.crypto.AEADBadTagException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.*; + +public class GcmAuthEngineSuite extends AuthEngineSuite { + + @Before + public void setUp() { + // Uses GCM mode + conf = getConf(2, false); + } + + @Test + public void testGcmEncryptedMessage() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + // Verify that it derives a GcmTransportCipher + assert (clientCipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + byte[] data = new byte[plaintextSegmentSize + (plaintextSegmentSize / 2)]; + // Just writing some bytes. + data[0] = 'a'; + data[data.length / 2] = 'b'; + data[data.length - 10] = 'c'; + ByteBuf buf = Unpooled.wrappedBuffer(data); + + // Mock the context and capture the arguments passed to it + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, buf, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + // Get the encrypted value and pass it to the decryption handler + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + encrypted.transferTo(channel, 0); + ciphertextBuffer.flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + // Capture the decrypted values and verify them + ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, ciphertext); + verify(ctx, times(2)) + .fireChannelRead(captorPlaintext.capture()); + ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); + assertEquals(plaintextSegmentSize/2, + lastPlaintextSegment.readableBytes()); + assertEquals('c', + lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); + } + } + + static class FakeRegion extends AbstractFileRegion { + private final ByteBuffer[] source; + private int sourcePosition; + private final long count; + + FakeRegion(ByteBuffer... source) { + this.source = source; + sourcePosition = 0; + count = remaining(); + } + + private long remaining() { + long remaining = 0; + for (ByteBuffer buffer : source) { + remaining += buffer.remaining(); + } + return remaining; + } + + @Override + public long position() { + return 0; + } + + @Override + public long transferred() { + return count - remaining(); + } + + @Override + public long count() { + return count; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + if (sourcePosition < source.length) { + ByteBuffer currentBuffer = source[sourcePosition]; + long written = target.write(currentBuffer); + if (!currentBuffer.hasRemaining()) { + sourcePosition++; + } + return written; + } else { + return 0; + } + } + + @Override + protected void deallocate() { + } + } + + private static ByteBuffer getTestByteBuf(int size, byte fill) { + byte[] data = new byte[size]; + Arrays.fill(data, fill); + return ByteBuffer.wrap(data); + } + + @Test + public void testGcmEncryptedMessageFileRegion() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + // Verify that it derives a GcmTransportCipher + assert (clientCipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + int halfSegmentSize = plaintextSegmentSize / 2; + int totalSize = plaintextSegmentSize + halfSegmentSize; + + // Set up some fragmented segments to test + ByteBuffer halfSegment = getTestByteBuf(halfSegmentSize, (byte) 'a'); + int smallFragmentSize = 128; + ByteBuffer smallFragment = getTestByteBuf(smallFragmentSize, (byte) 'b'); + int remainderSize = totalSize - halfSegmentSize - smallFragmentSize; + ByteBuffer remainder = getTestByteBuf(remainderSize, (byte) 'c'); + FakeRegion fakeRegion = new FakeRegion(halfSegment, smallFragment, remainder); + assertEquals(totalSize, fakeRegion.count()); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, fakeRegion, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + // Get the encrypted value and pass it to the decryption handler + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + + // We'll simulate the FileRegion only transferring half a segment. + // The encrypted message should buffer the partial segment plaintext. + long ciphertextTransferred = 0; + while (ciphertextTransferred < encrypted.count()) { + long chunkTransferred = encrypted.transferTo(channel, 0); + ciphertextTransferred += chunkTransferred; + } + assertEquals(encrypted.count(), ciphertextTransferred); + + ciphertextBuffer.flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + // Capture the decrypted values and verify them + ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, ciphertext); + verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); + ByteBuf plaintext = captorPlaintext.getValue(); + // We expect this to be the last partial plaintext segment + int expectedLength = totalSize % plaintextSegmentSize; + assertEquals(expectedLength, plaintext.readableBytes()); + // This will be the "remainder" segment that is filled to 'c' + assertEquals('c', plaintext.getByte(0)); + } + } + + + @Test + public void testGcmUnalignedDecryption() throws Exception { + TransportConf gcmConf = getConf(2, false); + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + TransportCipher clientCipher = server.sessionCipher(); + // Verify that it derives a GcmTransportCipher + assert (clientCipher instanceof GcmTransportCipher); + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + // Allocating 1.5x the buffer size to test multiple segments and a fractional segment. + int plaintextSegmentSize = GcmTransportCipher.CIPHERTEXT_BUFFER_SIZE - 16; + int plaintextSize = plaintextSegmentSize + (plaintextSegmentSize / 2); + byte[] data = new byte[plaintextSize]; + Arrays.fill(data, (byte) 'x'); + ByteBuf buf = Unpooled.wrappedBuffer(data); + + // Mock the context and capture the arguments passed to it + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, buf, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + // Get the encrypted value and pass it to the decryption handler + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + encrypted.transferTo(channel, 0); + ciphertextBuffer.flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + // Split up the ciphertext into some different sized chunks + int firstChunkSize = plaintextSize / 2; + ByteBuf mockCiphertext = spy(ciphertext); + when(mockCiphertext.readableBytes()) + .thenReturn(firstChunkSize, firstChunkSize).thenCallRealMethod(); + + // Capture the decrypted values and verify them + ArgumentCaptor captorPlaintext = ArgumentCaptor.forClass(ByteBuf.class); + decryptionHandler.channelRead(ctx, mockCiphertext); + verify(ctx, times(2)).fireChannelRead(captorPlaintext.capture()); + ByteBuf lastPlaintextSegment = captorPlaintext.getValue(); + assertEquals(plaintextSegmentSize/2, + lastPlaintextSegment.readableBytes()); + assertEquals('x', + lastPlaintextSegment.getByte((plaintextSegmentSize/2) - 10)); + } + } + + @Test + public void testCorruptGcmEncryptedMessage() throws Exception { + TransportConf gcmConf = getConf(2, false); + + try (AuthEngine client = new AuthEngine("appId", "secret", gcmConf); + AuthEngine server = new AuthEngine("appId", "secret", gcmConf)) { + AuthMessage clientChallenge = client.challenge(); + AuthMessage serverResponse = server.response(clientChallenge); + client.deriveSessionCipher(clientChallenge, serverResponse); + + TransportCipher clientCipher = server.sessionCipher(); + assert (clientCipher instanceof GcmTransportCipher); + + GcmTransportCipher gcmTransportCipher = (GcmTransportCipher) clientCipher; + GcmTransportCipher.EncryptionHandler encryptionHandler = + gcmTransportCipher.getEncryptionHandler(); + GcmTransportCipher.DecryptionHandler decryptionHandler = + gcmTransportCipher.getDecryptionHandler(); + byte[] zeroData = new byte[1024 * 32]; + // Just writing some bytes. + ByteBuf buf = Unpooled.wrappedBuffer(zeroData); + + // Mock the context and capture the arguments passed to it + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelPromise promise = mock(ChannelPromise.class); + ArgumentCaptor captorWrappedEncrypted = + ArgumentCaptor.forClass(GcmTransportCipher.GcmEncryptedMessage.class); + encryptionHandler.write(ctx, buf, promise); + verify(ctx).write(captorWrappedEncrypted.capture(), eq(promise)); + + GcmTransportCipher.GcmEncryptedMessage encrypted = + captorWrappedEncrypted.getValue(); + ByteBuffer ciphertextBuffer = + ByteBuffer.allocate((int) encrypted.count()); + ByteBufferWriteableChannel channel = + new ByteBufferWriteableChannel(ciphertextBuffer); + encrypted.transferTo(channel, 0); + ciphertextBuffer.flip(); + ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); + + byte b = ciphertext.getByte(100); + // Inverting the bits of the 100th bit + ciphertext.setByte(100, ~b & 0xFF); + assertThrows(AEADBadTagException.class, () -> decryptionHandler.channelRead(ctx, ciphertext)); + } + } +} diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java index cde5c1c1022c4..35f7886e174a9 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java @@ -41,10 +41,10 @@ public class TransportCipherSuite { @Test - public void testBufferNotLeaksOnInternalError() throws IOException { + public void testCtrBufferNotLeaksOnInternalError() throws IOException { String algorithm = "TestAlgorithm"; TransportConf conf = new TransportConf("Test", MapConfigProvider.EMPTY); - TransportCipher cipher = new TransportCipher(conf.cryptoConf(), conf.cipherTransformation(), + CtrTransportCipher cipher = new CtrTransportCipher(conf.cryptoConf(), new SecretKeySpec(new byte[256], algorithm), new byte[0], new byte[0]) { @Override diff --git a/docs/security.md b/docs/security.md index c0a4b4da03030..a9065a225017e 100644 --- a/docs/security.md +++ b/docs/security.md @@ -175,6 +175,15 @@ The following table describes the different options available for configuring th 2.2.0 + + spark.network.crypto.cipher + AES/CTR/NoPadding + + Cipher mode to use. Defaults "AES/CTR/NoPadding" for backward compatibility, which is not authenticated. + Recommended to use "AES/GCM/NoPadding", which is an authenticated encryption mode. + + 4.0.0 + spark.network.crypto.authEngineVersion 1 From aec3208c6ce35ac6e71481569fe21566cf097852 Mon Sep 17 00:00:00 2001 From: Steve Weis Date: Mon, 24 Jun 2024 10:09:05 -0700 Subject: [PATCH 2/3] [SPARK-47172][3.4] Fixing Java 8 build issues --- .../network/crypto/GcmTransportCipher.java | 35 ++++++++++--------- .../util/ByteBufferWriteableChannel.java | 6 ++-- .../network/crypto/GcmAuthEngineSuite.java | 9 ++--- 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java index 9599b78007374..6f9566662fd7d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java @@ -30,6 +30,7 @@ import javax.crypto.spec.SecretKeySpec; import java.io.IOException; +import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; import java.security.GeneralSecurityException; @@ -127,8 +128,7 @@ static class GcmEncryptedMessage extends AbstractFileRegion { // If the ciphertext buffer cannot be fully written the target, transferTo may // return with it containing some unwritten data. The initial call we'll explicitly // set its limit to 0 to indicate the first call to transferTo. - this.ciphertextBuffer.limit(0); - + ((Buffer) this.ciphertextBuffer).limit(0); this.bytesToRead = getReadableBytes(); this.encryptedCount = LENGTH_HEADER_BYTES + aesGcmHkdfStreaming.expectedCiphertextSize(bytesToRead); @@ -141,11 +141,12 @@ static class GcmEncryptedMessage extends AbstractFileRegion { // [8 byte length][Internal IV and header][Ciphertext][Auth Tag] private ByteBuffer createHeaderByteBuffer() { ByteBuffer encrypterHeader = encrypter.getHeader(); - return ByteBuffer + ByteBuffer output = ByteBuffer .allocate(encrypterHeader.remaining() + LENGTH_HEADER_BYTES) .putLong(encryptedCount) - .put(encrypterHeader) - .flip(); + .put(encrypterHeader); + ((Buffer) output).flip(); + return output; } @Override @@ -229,7 +230,7 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep if (plaintextMessage instanceof ByteBuf) { ByteBuf byteBuf = (ByteBuf) plaintextMessage; Preconditions.checkState(0 == plaintextBuffer.position()); - plaintextBuffer.limit(readLimit); + ((Buffer) plaintextBuffer).limit(readLimit); byteBuf.readBytes(plaintextBuffer); Preconditions.checkState(readLimit == plaintextBuffer.position()); } else if (plaintextMessage instanceof FileRegion) { @@ -245,16 +246,16 @@ public long transferTo(WritableByteChannel target, long position) throws IOExcep } } boolean lastSegment = getReadableBytes() == 0; - plaintextBuffer.flip(); + ((Buffer) plaintextBuffer).flip(); bytesRead += plaintextBuffer.remaining(); - ciphertextBuffer.clear(); + ((Buffer) ciphertextBuffer).clear(); try { encrypter.encryptSegment(plaintextBuffer, lastSegment, ciphertextBuffer); } catch (GeneralSecurityException e) { throw new IllegalStateException("GeneralSecurityException from encrypter", e); } - plaintextBuffer.clear(); - ciphertextBuffer.flip(); + ((Buffer) plaintextBuffer).clear(); + ((Buffer) ciphertextBuffer).flip(); int written = target.write(ciphertextBuffer); transferredThisCall += written; this.transferred += written; @@ -321,7 +322,7 @@ private boolean initalizeExpectedLength(ByteBuf ciphertextNettyBuf) { // We did not read enough bytes to initialize the expected length. return false; } - expectedLengthBuffer.flip(); + ((Buffer) expectedLengthBuffer).flip(); expectedLength = expectedLengthBuffer.getLong(); if (expectedLength < 0) { throw new IllegalStateException("Invalid expected ciphertext length."); @@ -341,7 +342,7 @@ private boolean initalizeDecrypter(ByteBuf ciphertextNettyBuf) // We did not read enough bytes to initialize the header. return false; } - headerBuffer.flip(); + ((Buffer) headerBuffer).flip(); byte[] lengthAad = Longs.toByteArray(expectedLength); decrypter.init(headerBuffer, lengthAad); decrypterInit = true; @@ -382,7 +383,7 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) int expectedRemaining = (int) (expectedLength - ciphertextRead); int bytesToRead = Integer.min(readableBytes, expectedRemaining); // The smallest ciphertext size is 16 bytes for the auth tag - ciphertextBuffer.limit(ciphertextBuffer.position() + bytesToRead); + ((Buffer) ciphertextBuffer).limit(((Buffer) ciphertextBuffer).position() + bytesToRead); ciphertextNettyBuf.readBytes(ciphertextBuffer); ciphertextRead += bytesToRead; // Check if this is the last segment @@ -395,7 +396,7 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) // then decrypt it and fire a read. if (ciphertextBuffer.limit() == ciphertextBuffer.capacity() || completed) { ByteBuffer plaintextBuffer = ByteBuffer.allocate(plaintextSegmentSize); - ciphertextBuffer.flip(); + ((Buffer) ciphertextBuffer).flip(); decrypter.decryptSegment( ciphertextBuffer, segmentNumber, @@ -403,12 +404,12 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) plaintextBuffer); segmentNumber++; // Clear the ciphertext buffer because it's been read - ciphertextBuffer.clear(); - plaintextBuffer.flip(); + ((Buffer) ciphertextBuffer).clear(); + ((Buffer) plaintextBuffer).flip(); ctx.fireChannelRead(Unpooled.wrappedBuffer(plaintextBuffer)); } else { // Set the ciphertext buffer up to read the next chunk - ciphertextBuffer.limit(ciphertextBuffer.capacity()); + ((Buffer) ciphertextBuffer).limit(ciphertextBuffer.capacity()); } nettyBufReadableBytes = ciphertextNettyBuf.readableBytes(); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java b/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java index b20240cfcaa6d..d49f46afa7ec4 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/ByteBufferWriteableChannel.java @@ -18,6 +18,7 @@ package org.apache.spark.network.util; import java.io.IOException; +import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.channels.ClosedChannelException; import java.nio.channels.WritableByteChannel; @@ -41,9 +42,10 @@ public int write(ByteBuffer src) throws IOException { if (bytesToWrite == 0) { return 0; } - ByteBuffer temp = src.slice().limit(bytesToWrite); + ByteBuffer temp = src.slice(); + ((Buffer) temp).limit(bytesToWrite); destination.put(temp); - src.position(src.position() + bytesToWrite); + ((Buffer) src).position(((Buffer) src).position() + bytesToWrite); return bytesToWrite; } diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java index 19e9eb41a1ef8..f25277aa1a997 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/GcmAuthEngineSuite.java @@ -30,6 +30,7 @@ import javax.crypto.AEADBadTagException; import java.io.IOException; +import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; import java.util.Arrays; @@ -87,7 +88,7 @@ public void testGcmEncryptedMessage() throws Exception { ByteBufferWriteableChannel channel = new ByteBufferWriteableChannel(ciphertextBuffer); encrypted.transferTo(channel, 0); - ciphertextBuffer.flip(); + ((Buffer) ciphertextBuffer).flip(); ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); // Capture the decrypted values and verify them @@ -216,7 +217,7 @@ public void testGcmEncryptedMessageFileRegion() throws Exception { } assertEquals(encrypted.count(), ciphertextTransferred); - ciphertextBuffer.flip(); + ((Buffer) ciphertextBuffer).flip(); ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); // Capture the decrypted values and verify them @@ -272,7 +273,7 @@ public void testGcmUnalignedDecryption() throws Exception { ByteBufferWriteableChannel channel = new ByteBufferWriteableChannel(ciphertextBuffer); encrypted.transferTo(channel, 0); - ciphertextBuffer.flip(); + ((Buffer) ciphertextBuffer).flip(); ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); // Split up the ciphertext into some different sized chunks @@ -330,7 +331,7 @@ public void testCorruptGcmEncryptedMessage() throws Exception { ByteBufferWriteableChannel channel = new ByteBufferWriteableChannel(ciphertextBuffer); encrypted.transferTo(channel, 0); - ciphertextBuffer.flip(); + ((Buffer) ciphertextBuffer).flip(); ByteBuf ciphertext = Unpooled.wrappedBuffer(ciphertextBuffer); byte b = ciphertext.getByte(100); From aca0295728ab742da999b49c6084617088ec9700 Mon Sep 17 00:00:00 2001 From: Steve Weis Date: Mon, 24 Jun 2024 14:29:37 -0700 Subject: [PATCH 3/3] [SPARK-47172][3.5] Fixing lint errors --- .../org/apache/spark/network/crypto/GcmTransportCipher.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java index 6f9566662fd7d..d3f1bf490d3a3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/GcmTransportCipher.java @@ -383,7 +383,8 @@ public void channelRead(ChannelHandlerContext ctx, Object ciphertextMessage) int expectedRemaining = (int) (expectedLength - ciphertextRead); int bytesToRead = Integer.min(readableBytes, expectedRemaining); // The smallest ciphertext size is 16 bytes for the auth tag - ((Buffer) ciphertextBuffer).limit(((Buffer) ciphertextBuffer).position() + bytesToRead); + ((Buffer) ciphertextBuffer).limit( + ((Buffer) ciphertextBuffer).position() + bytesToRead); ciphertextNettyBuf.readBytes(ciphertextBuffer); ciphertextRead += bytesToRead; // Check if this is the last segment