Skip to content

Commit

Permalink
Apply appropriate RPC handler to receive, receiveStream when auth ena…
Browse files Browse the repository at this point in the history
…bled
  • Loading branch information
srowen committed Apr 18, 2020
1 parent 29f5962 commit c80d5f7
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@
import org.slf4j.LoggerFactory;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.sasl.SaslRpcHandler;
import org.apache.spark.network.server.AbstractAuthRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.TransportConf;

/**
Expand All @@ -46,7 +45,7 @@
* The delegate will only receive messages if the given connection has been successfully
* authenticated. A connection may be authenticated at most once.
*/
class AuthRpcHandler extends RpcHandler {
class AuthRpcHandler extends AbstractAuthRpcHandler {
private static final Logger LOG = LoggerFactory.getLogger(AuthRpcHandler.class);

/** Transport configuration. */
Expand All @@ -55,36 +54,31 @@ class AuthRpcHandler extends RpcHandler {
/** The client channel. */
private final Channel channel;

/**
* RpcHandler we will delegate to for authenticated connections. When falling back to SASL
* this will be replaced with the SASL RPC handler.
*/
@VisibleForTesting
RpcHandler delegate;

/** Class which provides secret keys which are shared by server and client on a per-app basis. */
private final SecretKeyHolder secretKeyHolder;

/** Whether auth is done and future calls should be delegated. */
/** RPC handler for auth handshake when falling back to SASL auth. */
@VisibleForTesting
boolean doDelegate;
SaslRpcHandler saslHandler;

AuthRpcHandler(
TransportConf conf,
Channel channel,
RpcHandler delegate,
SecretKeyHolder secretKeyHolder) {
super(delegate);
this.conf = conf;
this.channel = channel;
this.delegate = delegate;
this.secretKeyHolder = secretKeyHolder;
}

@Override
public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
if (doDelegate) {
delegate.receive(client, message, callback);
return;
protected boolean doAuthChallenge(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {
if (saslHandler != null) {
return saslHandler.doAuthChallenge(client, message, callback);
}

int position = message.position();
Expand All @@ -98,18 +92,17 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
if (conf.saslFallback()) {
LOG.warn("Failed to parse new auth challenge, reverting to SASL for client {}.",
channel.remoteAddress());
delegate = new SaslRpcHandler(conf, channel, delegate, secretKeyHolder);
saslHandler = new SaslRpcHandler(conf, channel, null, secretKeyHolder);
message.position(position);
message.limit(limit);
delegate.receive(client, message, callback);
doDelegate = true;
return saslHandler.doAuthChallenge(client, message, callback);
} else {
LOG.debug("Unexpected challenge message from client {}, closing channel.",
channel.remoteAddress());
callback.onFailure(new IllegalArgumentException("Unknown challenge message."));
channel.close();
}
return;
return false;
}

// Here we have the client challenge, so perform the new auth protocol and set up the channel.
Expand All @@ -131,7 +124,7 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
LOG.debug("Authentication failed for client {}, closing channel.", channel.remoteAddress());
callback.onFailure(new IllegalArgumentException("Authentication failed."));
channel.close();
return;
return false;
} finally {
if (engine != null) {
try {
Expand All @@ -143,40 +136,6 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
}

LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
doDelegate = true;
}

@Override
public void receive(TransportClient client, ByteBuffer message) {
delegate.receive(client, message);
}

@Override
public StreamCallbackWithID receiveStream(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {
return delegate.receiveStream(client, message, callback);
return true;
}

@Override
public StreamManager getStreamManager() {
return delegate.getStreamManager();
}

@Override
public void channelActive(TransportClient client) {
delegate.channelActive(client);
}

@Override
public void channelInactive(TransportClient client) {
delegate.channelInactive(client);
}

@Override
public void exceptionCaught(Throwable cause, TransportClient client) {
delegate.exceptionCaught(cause, client);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@
import org.slf4j.LoggerFactory;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.AbstractAuthRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.TransportConf;

Expand All @@ -43,7 +42,7 @@
* Note that the authentication process consists of multiple challenge-response pairs, each of
* which are individual RPCs.
*/
public class SaslRpcHandler extends RpcHandler {
public class SaslRpcHandler extends AbstractAuthRpcHandler {
private static final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class);

/** Transport configuration. */
Expand All @@ -52,37 +51,28 @@ public class SaslRpcHandler extends RpcHandler {
/** The client channel. */
private final Channel channel;

/** RpcHandler we will delegate to for authenticated connections. */
private final RpcHandler delegate;

/** Class which provides secret keys which are shared by server and client on a per-app basis. */
private final SecretKeyHolder secretKeyHolder;

private SparkSaslServer saslServer;
private boolean isComplete;
private boolean isAuthenticated;

public SaslRpcHandler(
TransportConf conf,
Channel channel,
RpcHandler delegate,
SecretKeyHolder secretKeyHolder) {
super(delegate);
this.conf = conf;
this.channel = channel;
this.delegate = delegate;
this.secretKeyHolder = secretKeyHolder;
this.saslServer = null;
this.isComplete = false;
this.isAuthenticated = false;
}

@Override
public void receive(TransportClient client, ByteBuffer message, RpcResponseCallback callback) {
if (isComplete) {
// Authentication complete, delegate to base handler.
delegate.receive(client, message, callback);
return;
}
public boolean doAuthChallenge(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {
if (saslServer == null || !saslServer.isComplete()) {
ByteBuf nettyBuf = Unpooled.wrappedBuffer(message);
SaslMessage saslMessage;
Expand Down Expand Up @@ -118,55 +108,28 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
logger.debug("SASL authentication successful for channel {}", client);
complete(true);
return;
return true;
}

logger.debug("Enabling encryption for channel {}", client);
SaslEncryption.addToChannel(channel, saslServer, conf.maxSaslEncryptedBlockSize());
complete(false);
return;
return true;
}
}

@Override
public void receive(TransportClient client, ByteBuffer message) {
delegate.receive(client, message);
}

@Override
public StreamCallbackWithID receiveStream(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {
return delegate.receiveStream(client, message, callback);
}

@Override
public StreamManager getStreamManager() {
return delegate.getStreamManager();
}

@Override
public void channelActive(TransportClient client) {
delegate.channelActive(client);
return false;
}

@Override
public void channelInactive(TransportClient client) {
try {
delegate.channelInactive(client);
super.channelInactive(client);
} finally {
if (saslServer != null) {
saslServer.dispose();
}
}
}

@Override
public void exceptionCaught(Throwable cause, TransportClient client) {
delegate.exceptionCaught(cause, client);
}

private void complete(boolean dispose) {
if (dispose) {
try {
Expand All @@ -177,7 +140,6 @@ private void complete(boolean dispose) {
}

saslServer = null;
isComplete = true;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.server;

import java.nio.ByteBuffer;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallbackWithID;
import org.apache.spark.network.client.TransportClient;

/**
* RPC Handler which performs authentication, and when it's successful, delegates further
* calls to another RPC handler. The authentication handshake itself should be implemented
* by subclasses.
*/
public abstract class AbstractAuthRpcHandler extends RpcHandler {
/** RpcHandler we will delegate to for authenticated connections. */
private final RpcHandler delegate;

private boolean isAuthenticated;

protected AbstractAuthRpcHandler(RpcHandler delegate) {
this.delegate = delegate;
}

/**
* Responds to an authentication challenge.
*
* @return Whether the client is authenticated.
*/
protected abstract boolean doAuthChallenge(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback);

@Override
public final void receive(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {
if (isAuthenticated) {
delegate.receive(client, message, callback);
} else {
isAuthenticated = doAuthChallenge(client, message, callback);
}
}

@Override
public final void receive(TransportClient client, ByteBuffer message) {
if (isAuthenticated) {
delegate.receive(client, message);
} else {
throw new SecurityException("Unauthenticated call to receive().");
}
}

@Override
public final StreamCallbackWithID receiveStream(
TransportClient client,
ByteBuffer message,
RpcResponseCallback callback) {
if (isAuthenticated) {
return delegate.receiveStream(client, message, callback);
} else {
throw new SecurityException("Unauthenticated call to receiveStream().");
}
}

@Override
public StreamManager getStreamManager() {
return delegate.getStreamManager();
}

@Override
public void channelActive(TransportClient client) {
delegate.channelActive(client);
}

@Override
public void channelInactive(TransportClient client) {
delegate.channelInactive(client);
}

@Override
public void exceptionCaught(Throwable cause, TransportClient client) {
delegate.exceptionCaught(cause, client);
}

public boolean isAuthenticated() {
return isAuthenticated;
}
}
Loading

0 comments on commit c80d5f7

Please sign in to comment.