Skip to content

Commit

Permalink
[SPARK-13331] AES support for over-the-wire encryption
Browse files Browse the repository at this point in the history
Use encryption streaming from Apache Common Crypto to simply the encryption pipeline.
  • Loading branch information
Junjie Chen committed Oct 11, 2016
2 parents c1936eb + d8e7baf commit 0bf663f
Show file tree
Hide file tree
Showing 14 changed files with 580 additions and 477 deletions.
Expand Up @@ -18,16 +18,21 @@
package org.apache.spark.network.sasl;

import java.io.IOException;
import java.nio.ByteBuffer;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.sasl.aes.SparkAesSaslClient;
import org.apache.spark.network.sasl.aes.AesEncryption;
import org.apache.spark.network.sasl.aes.AesCipher;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.TransportConf;

/**
Expand Down Expand Up @@ -64,13 +69,20 @@ public SaslClientBootstrap(
*/
@Override
public void doBootstrap(TransportClient client, Channel channel) {
boolean aesEnable = conf.saslEncryptionAesEnabled();
SparkSaslClient saslClient = aesEnable ?
new SparkAesSaslClient(appId, secretKeyHolder, encrypt) :
new SparkSaslClient(appId, secretKeyHolder, encrypt);

SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, encrypt);
try {
saslClient.negotiate(client, conf);
byte[] payload = saslClient.firstToken();

while (!saslClient.isComplete()) {
SaslMessage msg = new SaslMessage(appId, payload);
ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
msg.encode(buf);
buf.writeBytes(msg.body().nioByteBuffer());

ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs());
payload = saslClient.response(JavaUtils.bufferToArray(response));
}

client.setClientId(appId);

if (encrypt) {
Expand All @@ -79,7 +91,15 @@ public void doBootstrap(TransportClient client, Channel channel) {
new SaslException("Encryption requests by negotiated non-encrypted connection."));
}

SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
if(conf.saslEncryptionAesEnabled()) {
Object result = saslClient.negotiate(client, conf);
if (result instanceof AesCipher) {
logger.info("Enabling AES encryption for client channel {}", client);
AesEncryption.addToChannel(channel, (AesCipher) result);
}
} else {
SaslEncryption.addToChannel(channel, saslClient, conf.maxSaslEncryptedBlockSize());
}
saslClient = null;
logger.debug("Channel {} configured for SASL encryption.", client);
}
Expand Down
Expand Up @@ -20,6 +20,7 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
Expand All @@ -29,7 +30,8 @@

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.sasl.aes.SparkAesSaslServer;
import org.apache.spark.network.sasl.aes.AesEncryption;
import org.apache.spark.network.sasl.aes.AesCipher;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.JavaUtils;
Expand Down Expand Up @@ -60,6 +62,7 @@ class SaslRpcHandler extends RpcHandler {

private SparkSaslServer saslServer;
private boolean isComplete;
private boolean isAuthenticated;

SaslRpcHandler(
TransportConf conf,
Expand All @@ -72,11 +75,11 @@ class SaslRpcHandler extends RpcHandler {
this.secretKeyHolder = secretKeyHolder;
this.saslServer = null;
this.isComplete = false;
this.isAuthenticated = false;
}

@Override
public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
boolean aesEnable = conf.saslEncryptionAesEnabled();
boolean encrypt = conf.saslServerAlwaysEncrypt();
if (isComplete) {
// Authentication complete, delegate to base handler.
Expand All @@ -95,9 +98,7 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
if (saslServer == null) {
// First message in the handshake, setup the necessary state.
client.setClientId(saslMessage.appId);
saslServer = aesEnable ?
new SparkAesSaslServer(saslMessage.appId, secretKeyHolder, encrypt) :
new SparkSaslServer(saslMessage.appId, secretKeyHolder, encrypt);
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, encrypt);
}

byte[] response;
Expand All @@ -117,11 +118,29 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
// messages are being written to the channel while negotiation is still going on.
if (saslServer.isComplete()) {
if (SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
// Return directly if negotiate not finish.
if (!saslServer.negotiate(message, callback, conf)) return;
logger.debug("Enabling encryption for channel {}", client);
SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
saslServer = null;
try {
if (conf.saslEncryptionAesEnabled()) {
// Extra negotiation should happen after authentication, so return directly while
// processing authenticate.
if (!isAuthenticated) {
logger.debug("SASL authentication successful for channel {}", client);
isAuthenticated = true;
return ;
} else {
Object result = saslServer.negotiate(message, callback, conf);
if (result instanceof AesCipher) {
logger.info("Enabling AES encryption for Server channel {}", client);
AesEncryption.addToChannel(channel, (AesCipher) result);
}
}
} else {
logger.debug("Enabling encryption for channel {}", client);
SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
}
saslServer = null;
} catch (SaslException e) {
return ;
}
} else {
logger.debug("SASL authentication successful for channel {}", client);
saslServer.dispose();
Expand Down Expand Up @@ -161,4 +180,5 @@ public void channelInactive(TransportClient client) {
public void exceptionCaught(Throwable cause, TransportClient client) {
delegate.exceptionCaught(cause, client);
}
}

}
Expand Up @@ -20,6 +20,7 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.Properties;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
Expand All @@ -35,11 +36,13 @@
import com.google.common.collect.ImmutableMap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.sasl.aes.AesCipherOption;
import org.apache.spark.network.sasl.aes.AesCipher;
import org.apache.spark.network.util.TransportConf;

import static org.apache.spark.network.sasl.SparkSaslServer.*;
Expand Down Expand Up @@ -127,22 +130,35 @@ public synchronized void dispose() {
}

/**
* Negotiate extra encryption options for SASL
* @param client is transport client used to connect to peer.
* @param conf contain client transport configuration.
* @throws IOException
* @return The object represent the result of negotiate.
*/
public void negotiate(TransportClient client, TransportConf conf) throws IOException {
byte[] payload = firstToken();
public Object negotiate(TransportClient client, TransportConf conf) throws IOException {
// Create option for negotiation
AesCipherOption cipherOption = new AesCipherOption();
ByteBuf buf = Unpooled.buffer(cipherOption.encodedLength());
cipherOption.encode(buf);

while (!isComplete()) {
SaslMessage msg = new SaslMessage(secretKeyId, payload);
ByteBuf buf = Unpooled.buffer(msg.encodedLength() + (int) msg.body().size());
msg.encode(buf);
buf.writeBytes(msg.body().nioByteBuffer());
// Send option to server and decode received negotiated option
ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs());
cipherOption = AesCipherOption.decode(Unpooled.wrappedBuffer(response));

ByteBuffer response = client.sendRpcSync(buf.nioBuffer(), conf.saslRTTimeoutMs());
payload = response(JavaUtils.bufferToArray(response));
}
// Decrypt key from option. Server's outKey is client's inKey, and vice versa.
byte[] outKey = unwrap(cipherOption.inKey, 0, cipherOption.inKey.length);
byte[] inKey = unwrap(cipherOption.outKey, 0, cipherOption.outKey.length);

// Enable AES on SaslClient
Properties properties = new Properties();

AesCipher cipher = new AesCipher(properties, inKey, outKey,
cipherOption.outIv, cipherOption.inIv);

logger.debug("AES enabled for SASL client encryption.");

return cipher;
}

/**
Expand Down Expand Up @@ -184,4 +200,5 @@ public byte[] wrap(byte[] data, int offset, int len) throws SaslException {
public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
return saslClient.unwrap(data, offset, len);
}

}
Expand Up @@ -31,16 +31,23 @@
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import java.util.Map;
import java.util.Properties;

import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableMap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.base64.Base64;
import org.apache.commons.crypto.cipher.CryptoCipherFactory;
import org.apache.commons.crypto.random.CryptoRandom;
import org.apache.commons.crypto.random.CryptoRandomFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.sasl.aes.AesCipherOption;
import org.apache.spark.network.sasl.aes.AesCipher;
import org.apache.spark.network.util.TransportConf;

/**
Expand Down Expand Up @@ -158,10 +165,55 @@ public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
* @param message is message receive from peer which may contains communication parameters.
* @param callback is rpc callback.
* @param conf contains transport configuration.
* @return true if negotiate finish successfully, else false.
* @return Object which represent the result of negotiate.
*/
public boolean negotiate(ByteBuffer message, RpcResponseCallback callback, TransportConf conf) {
return true;
public Object negotiate(ByteBuffer message, RpcResponseCallback callback, TransportConf conf)
throws SaslException {
AesCipher cipher;

// Receive initial option from client
AesCipherOption cipherOption = AesCipherOption.decode(Unpooled.wrappedBuffer(message));
String transformation = AesCipher.TRANSFORM;
Properties properties = new Properties();

try {
// Generate key and iv
if (conf.saslEncryptionAesCipherKeySizeBits() % 8 != 0) {
throw new IllegalArgumentException("The AES cipher key size in bits should be a multiple " +
"of byte");
}

int keyLen = conf.saslEncryptionAesCipherKeySizeBits() / 8;
int paramLen = CryptoCipherFactory.getCryptoCipher(transformation,properties).getBlockSize();
byte[] inKey = new byte[keyLen];
byte[] outKey = new byte[keyLen];
byte[] inIv = new byte[paramLen];
byte[] outIv = new byte[paramLen];

// Get the 'CryptoRandom' instance.
CryptoRandom random = CryptoRandomFactory.getCryptoRandom(properties);
random.nextBytes(inKey);
random.nextBytes(outKey);
random.nextBytes(inIv);
random.nextBytes(outIv);

// Update cipher option for client. The key is encrypted.
cipherOption.setParameters(wrap(inKey, 0, inKey.length), inIv,
wrap(outKey, 0, outKey.length), outIv);

// Enable AES on saslServer
cipher = new AesCipher(properties, inKey, outKey, inIv, outIv);

// Send cipher option to client
ByteBuf buf = Unpooled.buffer(cipherOption.encodedLength());
cipherOption.encode(buf);
callback.onSuccess(buf.nioBuffer());
} catch (Exception e) {
logger.error("AES negotiation exception: ", e);
throw Throwables.propagate(e);
}

return cipher;
}

/**
Expand Down Expand Up @@ -212,4 +264,4 @@ public static char[] encodePassword(String password) {
return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(StandardCharsets.UTF_8)))
.toString(StandardCharsets.UTF_8).toCharArray();
}
}
}

0 comments on commit 0bf663f

Please sign in to comment.