Skip to content

Commit

Permalink
Improve channel tracking for gRPC authentication traces
Browse files Browse the repository at this point in the history
Main purpose of this PR is to make gRPC authentication tracing to
contain channel information in order to improve debugging.

pr-link: #9316
change-id: cid-bde726e6095f480cf91c1da586140d846c37ff6e
  • Loading branch information
Göktürk Gezer authored and alluxio-bot committed Jun 19, 2019
1 parent c659ba3 commit b2e7ff6
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 29 deletions.
2 changes: 1 addition & 1 deletion core/common/src/main/java/alluxio/grpc/GrpcChannel.java
Expand Up @@ -70,7 +70,7 @@ public GrpcChannel(GrpcManagedChannelPool.ChannelKey channelKey, AuthenticatedCh


// Store {@link AuthenticatedChannel::#close) for signaling end of // Store {@link AuthenticatedChannel::#close) for signaling end of
// authenticated session during shutdown. // authenticated session during shutdown.
mAuthCloseCallback = ((AuthenticatedChannel) channel)::close; mAuthCloseCallback = channel::close;
} }


@Override @Override
Expand Down
Expand Up @@ -152,14 +152,14 @@ public void authenticate() throws AlluxioStatusException {
SaslHandshakeClientHandler handshakeClient = SaslHandshakeClientHandler handshakeClient =
new DefaultSaslHandshakeClientHandler(saslClientHandler); new DefaultSaslHandshakeClientHandler(saslClientHandler);
// Create driver for driving sasl traffic from client side. // Create driver for driving sasl traffic from client side.
mClientDriver = mClientDriver = new SaslStreamClientDriver(handshakeClient, mAuthenticated, mChannelId,
new SaslStreamClientDriver(handshakeClient, mAuthenticated, mGrpcAuthTimeoutMs); mGrpcAuthTimeoutMs);
// Start authentication call with the service and update the client driver. // Start authentication call with the service and update the client driver.
StreamObserver<SaslMessage> requestObserver = StreamObserver<SaslMessage> requestObserver =
SaslAuthenticationServiceGrpc.newStub(mManagedChannel).authenticate(mClientDriver); SaslAuthenticationServiceGrpc.newStub(mManagedChannel).authenticate(mClientDriver);
mClientDriver.setServerObserver(requestObserver); mClientDriver.setServerObserver(requestObserver);
// Start authentication traffic with the target. // Start authentication traffic with the target.
mClientDriver.start(mChannelId.toString()); mClientDriver.start();
// Authentication succeeded! // Authentication succeeded!
mManagedChannel.notifyWhenStateChanged(ConnectivityState.READY, () -> { mManagedChannel.notifyWhenStateChanged(ConnectivityState.READY, () -> {
mAuthenticated.set(false); mAuthenticated.set(false);
Expand Down
Expand Up @@ -20,6 +20,7 @@


import javax.security.sasl.SaslClient; import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException; import javax.security.sasl.SaslException;
import java.util.UUID;


/** /**
* Default implementation of {@link SaslHandshakeClientHandler}. * Default implementation of {@link SaslHandshakeClientHandler}.
Expand Down Expand Up @@ -69,7 +70,7 @@ public SaslMessage handleSaslMessage(SaslMessage message) throws SaslException {
} }


@Override @Override
public SaslMessage getInitialMessage(String channelId) throws SaslException { public SaslMessage getInitialMessage(UUID channelId) throws SaslException {
byte[] initiateSaslResponse = null; byte[] initiateSaslResponse = null;
if (mSaslClientHandler.getSaslClient().hasInitialResponse()) { if (mSaslClientHandler.getSaslClient().hasInitialResponse()) {
initiateSaslResponse = mSaslClient.evaluateChallenge(S_INITIATE_CHALLENGE); initiateSaslResponse = mSaslClient.evaluateChallenge(S_INITIATE_CHALLENGE);
Expand All @@ -80,7 +81,7 @@ public SaslMessage getInitialMessage(String channelId) throws SaslException {
if (initiateSaslResponse != null) { if (initiateSaslResponse != null) {
initialResponse.setMessage(ByteString.copyFrom(initiateSaslResponse)); initialResponse.setMessage(ByteString.copyFrom(initiateSaslResponse));
} }
initialResponse.setClientId(channelId); initialResponse.setClientId(channelId.toString());
return initialResponse.build(); return initialResponse.build();
} }
} }
Expand Up @@ -14,6 +14,7 @@
import alluxio.grpc.SaslMessage; import alluxio.grpc.SaslMessage;


import javax.security.sasl.SaslException; import javax.security.sasl.SaslException;
import java.util.UUID;


/** /**
* Interface for providing client-side handshake routines for a particular authentication scheme. * Interface for providing client-side handshake routines for a particular authentication scheme.
Expand All @@ -29,9 +30,9 @@ public interface SaslHandshakeClientHandler {
public SaslMessage handleSaslMessage(SaslMessage message) throws SaslException; public SaslMessage handleSaslMessage(SaslMessage message) throws SaslException;


/** /**
* @param channelId channe for which the authentication is happening * @param channelId channel for which the authentication is happening
* @return the initial message for Sasl traffic to begin * @return the initial message for Sasl traffic to begin
* @throws SaslException * @throws SaslException
*/ */
public SaslMessage getInitialMessage(String channelId) throws SaslException; public SaslMessage getInitialMessage(UUID channelId) throws SaslException;
} }
Expand Up @@ -25,6 +25,7 @@
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;


import java.util.UUID;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException; import java.util.concurrent.TimeoutException;
Expand Down Expand Up @@ -58,6 +59,8 @@ public class SaslStreamClientDriver implements StreamObserver<SaslMessage> {
private SettableFuture<Boolean> mHandshakeFuture; private SettableFuture<Boolean> mHandshakeFuture;
/** Current authenticated state. */ /** Current authenticated state. */
private AtomicBoolean mAuthenticated; private AtomicBoolean mAuthenticated;
/** Channel Id that has started the authentication. */
private UUID mChannelId;


private final long mGrpcAuthTimeoutMs; private final long mGrpcAuthTimeoutMs;


Expand All @@ -66,12 +69,14 @@ public class SaslStreamClientDriver implements StreamObserver<SaslMessage> {
* *
* @param handshakeClient client handshake handler * @param handshakeClient client handshake handler
* @param authenticated boolean reference to receive authentication state changes * @param authenticated boolean reference to receive authentication state changes
* @param channelId channel Id for authentication
* @param grpcAuthTimeoutMs authentication timeout in milliseconds * @param grpcAuthTimeoutMs authentication timeout in milliseconds
*/ */
public SaslStreamClientDriver(SaslHandshakeClientHandler handshakeClient, public SaslStreamClientDriver(SaslHandshakeClientHandler handshakeClient,
AtomicBoolean authenticated, long grpcAuthTimeoutMs) { AtomicBoolean authenticated, UUID channelId, long grpcAuthTimeoutMs) {
mSaslHandshakeClientHandler = handshakeClient; mSaslHandshakeClientHandler = handshakeClient;
mHandshakeFuture = SettableFuture.create(); mHandshakeFuture = SettableFuture.create();
mChannelId = channelId;
mGrpcAuthTimeoutMs = grpcAuthTimeoutMs; mGrpcAuthTimeoutMs = grpcAuthTimeoutMs;
mAuthenticated = authenticated; mAuthenticated = authenticated;
} }
Expand All @@ -88,43 +93,44 @@ public void setServerObserver(StreamObserver<SaslMessage> requestObserver) {
@Override @Override
public void onNext(SaslMessage saslMessage) { public void onNext(SaslMessage saslMessage) {
try { try {
LOG.debug("SaslClientDriver received message: {}", LOG.debug("SaslClientDriver received message: {} for channel: {}", saslMessage, mChannelId);
saslMessage != null ? saslMessage.getMessageType().toString() : "<NULL>");
SaslMessage response = mSaslHandshakeClientHandler.handleSaslMessage(saslMessage); SaslMessage response = mSaslHandshakeClientHandler.handleSaslMessage(saslMessage);
if (response != null) { if (response != null) {
mRequestObserver.onNext(response); mRequestObserver.onNext(response);
} else { } else {
// {@code null} response means server message was a success. // {@code null} response means server message was a success.
mHandshakeFuture.set(true); mHandshakeFuture.set(true);
} }
} catch (SaslException e) { } catch (Exception e) {
LOG.debug("Exception while handling SASL message: {} for channel: {}. Error: {}", saslMessage,
mChannelId, e);
mHandshakeFuture.setException(e); mHandshakeFuture.setException(e);
mRequestObserver mRequestObserver.onError(e);
.onError(Status.fromCode(Status.Code.UNAUTHENTICATED).withCause(e).asException());
} }
} }


@Override @Override
public void onError(Throwable throwable) { public void onError(Throwable throwable) {
LOG.warn("Received error on client driver for channel: {}. Error: {}", mChannelId, throwable);
mHandshakeFuture.setException(throwable); mHandshakeFuture.setException(throwable);
} }


@Override @Override
public void onCompleted() { public void onCompleted() {
LOG.debug("Client authentication closed by server for channel: {}", mChannelId);
// Server completes the stream when authenticated session is terminated/revoked. // Server completes the stream when authenticated session is terminated/revoked.
mAuthenticated.set(false); mAuthenticated.set(false);
} }


/** /**
* Starts authentication with the server and wait until completion. * Starts authentication with the server and wait until completion.
* @param channelId channel that is authenticating with the server
* @throws UnauthenticatedException * @throws UnauthenticatedException
*/ */
public void start(String channelId) throws AlluxioStatusException { public void start() throws AlluxioStatusException {
try { try {
LOG.debug("Starting SASL handshake for ChannelId:{}", channelId); LOG.debug("Starting SASL handshake for channel: {}", mChannelId);
// Send the server initial message. // Send the server initial message.
mRequestObserver.onNext(mSaslHandshakeClientHandler.getInitialMessage(channelId)); mRequestObserver.onNext(mSaslHandshakeClientHandler.getInitialMessage(mChannelId));
// Wait until authentication status changes. // Wait until authentication status changes.
mAuthenticated.set(mHandshakeFuture.get(mGrpcAuthTimeoutMs, TimeUnit.MILLISECONDS)); mAuthenticated.set(mHandshakeFuture.get(mGrpcAuthTimeoutMs, TimeUnit.MILLISECONDS));
} catch (SaslException se) { } catch (SaslException se) {
Expand Down Expand Up @@ -152,6 +158,7 @@ public void start(String channelId) throws AlluxioStatusException {
* Stops authenticated session with the server by releasing the long poll. * Stops authenticated session with the server by releasing the long poll.
*/ */
public void stop() { public void stop() {
LOG.debug("Closing client driver for channel: {}", mChannelId);
try { try {
if (mAuthenticated.get()) { if (mAuthenticated.get()) {
mRequestObserver.onCompleted(); mRequestObserver.onCompleted();
Expand Down
Expand Up @@ -22,6 +22,7 @@


import java.io.IOException; import java.io.IOException;
import java.util.UUID; import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;


import javax.security.sasl.SaslException; import javax.security.sasl.SaslException;


Expand All @@ -30,16 +31,21 @@
*/ */
public class SaslStreamServerDriver implements StreamObserver<SaslMessage> { public class SaslStreamServerDriver implements StreamObserver<SaslMessage> {
private static final Logger LOG = LoggerFactory.getLogger(SaslStreamServerDriver.class); private static final Logger LOG = LoggerFactory.getLogger(SaslStreamServerDriver.class);

/** Used to represent uninitialized channel Id. */
private static final UUID EMPTY_UUID = new UUID(0L, 0L);
/** Client's sasl stream. */ /** Client's sasl stream. */
private StreamObserver<SaslMessage> mRequestObserver = null; private StreamObserver<SaslMessage> mRequestObserver = null;
/** Handshake handler for server. */ /** Handshake handler for server. */
private SaslHandshakeServerHandler mSaslHandshakeServerHandler; private SaslHandshakeServerHandler mSaslHandshakeServerHandler;
/** Authentication server. */ /** Authentication server. */
private AuthenticationServer mAuthenticationServer; private AuthenticationServer mAuthenticationServer;
/** Id for client-side channel that is authenticating. */ /** Id for client-side channel that is authenticating. */
private UUID mChannelId; private UUID mChannelId = EMPTY_UUID;
/** Sasl server handler that will be used for authentication. */ /** Sasl server handler that will be used for authentication. */
private SaslServerHandler mSaslServerHandler = null; private SaslServerHandler mSaslServerHandler = null;
/** Whether client stream observer is still valid. */
private AtomicBoolean mClientStreamValid = new AtomicBoolean(true);


/** /**
* Creates {@link SaslStreamServerDriver} for given {@link AuthenticationServer}. * Creates {@link SaslStreamServerDriver} for given {@link AuthenticationServer}.
Expand All @@ -64,12 +70,11 @@ public void onNext(SaslMessage saslMessage) {
/** Whether to close the handler after this message. */ /** Whether to close the handler after this message. */
boolean closeHandler = false; boolean closeHandler = false;
try { try {
LOG.debug("SaslServerDriver received message: {}",
saslMessage != null ? saslMessage.getMessageType().toString() : "<NULL>");

if (mSaslHandshakeServerHandler == null) { if (mSaslHandshakeServerHandler == null) {
// First message received from the client. // First message received from the client.
// ChannelId and the AuthenticationName will be set only in the first call. // ChannelId and the AuthenticationName will be set only in the first call.
LOG.debug("SaslServerDriver received authentication request of type:{} from channel: {}",
saslMessage.getAuthenticationScheme(), saslMessage.getClientId());
// Initialize this server driver accordingly. // Initialize this server driver accordingly.
mChannelId = UUID.fromString(saslMessage.getClientId()); mChannelId = UUID.fromString(saslMessage.getClientId());
// Get authentication server to create the Sasl handler for requested scheme. // Get authentication server to create the Sasl handler for requested scheme.
Expand All @@ -79,6 +84,8 @@ public void onNext(SaslMessage saslMessage) {
// Unregister from registry if in case it was authenticated before. // Unregister from registry if in case it was authenticated before.
mAuthenticationServer.unregisterChannel(mChannelId); mAuthenticationServer.unregisterChannel(mChannelId);
} }

LOG.debug("SaslServerDriver received message: {} from channel: {}", saslMessage, mChannelId);
// Respond to client. // Respond to client.
SaslMessage response = mSaslHandshakeServerHandler.handleSaslMessage(saslMessage); SaslMessage response = mSaslHandshakeServerHandler.handleSaslMessage(saslMessage);
// Complete the call from server-side before sending success response to client. // Complete the call from server-side before sending success response to client.
Expand All @@ -92,15 +99,18 @@ public void onNext(SaslMessage saslMessage) {
} }
mRequestObserver.onNext(response); mRequestObserver.onNext(response);
} catch (SaslException se) { } catch (SaslException se) {
LOG.debug("Exception while handling SASL message: {}", saslMessage, se); LOG.debug("Exception while handling SASL message: {} for channel: {}. Error: {}", saslMessage,
mChannelId, se);
mRequestObserver.onError(new UnauthenticatedException(se).toGrpcStatusException()); mRequestObserver.onError(new UnauthenticatedException(se).toGrpcStatusException());
closeHandler = true; closeHandler = true;
} catch (UnauthenticatedException ue) { } catch (UnauthenticatedException ue) {
LOG.debug("Exception while handling SASL message: {}", saslMessage, ue); LOG.debug("Exception while handling SASL message: {} for channel: {}. Error: {}", saslMessage,
mChannelId, ue);
mRequestObserver.onError(ue.toGrpcStatusException()); mRequestObserver.onError(ue.toGrpcStatusException());
closeHandler = true; closeHandler = true;
} catch (Exception e) { } catch (Exception e) {
LOG.debug("Exception while handling SASL message: {}", saslMessage, e); LOG.debug("Exception while handling SASL message: {} for channel: {}. Error: {}", saslMessage,
mChannelId, e);
closeHandler = true; closeHandler = true;
throw e; throw e;
} finally { } finally {
Expand All @@ -116,7 +126,11 @@ public void onNext(SaslMessage saslMessage) {


@Override @Override
public void onError(Throwable throwable) { public void onError(Throwable throwable) {
if (mChannelId != null) { LOG.warn("Error received for channel: {}. Error: {}", mChannelId, throwable);
// Error on server invalidates client stream.
mClientStreamValid.set(false);

if (!mChannelId.equals(EMPTY_UUID)) {
LOG.debug("Closing authenticated channel: {} due to error: {}", mChannelId, throwable); LOG.debug("Closing authenticated channel: {} due to error: {}", mChannelId, throwable);
if (!mAuthenticationServer.unregisterChannel(mChannelId)) { if (!mAuthenticationServer.unregisterChannel(mChannelId)) {
// Channel was not registered. Close driver explicitly. // Channel was not registered. Close driver explicitly.
Expand All @@ -140,6 +154,7 @@ public void onCompleted() {
* Closes the authentication stream. * Closes the authentication stream.
*/ */
public void close() { public void close() {
LOG.debug("Closing server driver for channel: {}", mChannelId);
// Complete the client stream. // Complete the client stream.
completeStreamQuietly(); completeStreamQuietly();
// Close handler if not already. // Close handler if not already.
Expand All @@ -148,7 +163,7 @@ public void close() {
mSaslServerHandler.close(); mSaslServerHandler.close();
} catch (Exception exc) { } catch (Exception exc) {
LogUtils.warnWithException(LOG, "Failed to close server driver for channel: {}.", LogUtils.warnWithException(LOG, "Failed to close server driver for channel: {}.",
(mChannelId != null) ? mChannelId : "<NULL>", exc); mChannelId, exc);
} }
} }
} }
Expand All @@ -157,11 +172,14 @@ public void close() {
* Completes the stream with a debug blanket over possible exceptions. * Completes the stream with a debug blanket over possible exceptions.
*/ */
private void completeStreamQuietly() { private void completeStreamQuietly() {
if (mRequestObserver != null) { if (mClientStreamValid.get() && mRequestObserver != null) {
try { try {
mRequestObserver.onCompleted(); mRequestObserver.onCompleted();
} catch (Exception exc) { } catch (Exception exc) {
LOG.debug("Failed to close authentication stream from server.", exc); LOG.debug("Failed to close authentication stream for channel: {}. Error: {}", mChannelId,
exc);
} finally {
mClientStreamValid.set(false);
} }
} }
} }
Expand Down

0 comments on commit b2e7ff6

Please sign in to comment.