Skip to content

Commit

Permalink
Add tracing for channel pooling/authentication (#8382)
Browse files Browse the repository at this point in the history
* Add tracing for channel pooling/authentication
* Trace channel authentication failure with target host
* Repackage channel authentication error with inner cause
  • Loading branch information
ggezer committed Feb 12, 2019
1 parent f42860e commit dfe2bb0
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 35 deletions.
Expand Up @@ -157,6 +157,54 @@ public static AlluxioStatusException from(Status status, String m) {
}
}

/**
* Converts an Alluxio exception from status and message representation to native representation.
* The status must not be null or {@link Status#OK}.
*
* @param status the status
* @param m the message
* @param cause the cause
* @return an {@link AlluxioStatusException} for the given status and message
*/
public static AlluxioStatusException from(Status status, String m, Throwable cause) {
Preconditions.checkNotNull(status, "status");
Preconditions.checkArgument(status != Status.OK, "OK is not an error status");
switch (status) {
case CANCELED:
return new CanceledException(m, cause);
case INVALID_ARGUMENT:
return new InvalidArgumentException(m, cause);
case DEADLINE_EXCEEDED:
return new DeadlineExceededException(m, cause);
case NOT_FOUND:
return new NotFoundException(m, cause);
case ALREADY_EXISTS:
return new AlreadyExistsException(m, cause);
case PERMISSION_DENIED:
return new PermissionDeniedException(m, cause);
case UNAUTHENTICATED:
return new UnauthenticatedException(m, cause);
case RESOURCE_EXHAUSTED:
return new ResourceExhaustedException(m, cause);
case FAILED_PRECONDITION:
return new FailedPreconditionException(m, cause);
case ABORTED:
return new AbortedException(m, cause);
case OUT_OF_RANGE:
return new OutOfRangeException(m, cause);
case UNIMPLEMENTED:
return new UnimplementedException(m, cause);
case INTERNAL:
return new InternalException(m, cause);
case UNAVAILABLE:
return new UnavailableException(m, cause);
case DATA_LOSS:
return new DataLossException(m, cause);
default:
return new UnknownException(m, cause);
}
}

/**
* Converts checked throwables to Alluxio status exceptions. Unchecked throwables should not be
* passed to this method. Use Throwables.propagateIfPossible before passing a Throwable to this
Expand Down
34 changes: 28 additions & 6 deletions core/common/src/main/java/alluxio/grpc/GrpcManagedChannelPool.java
Expand Up @@ -76,7 +76,16 @@ public GrpcManagedChannelPool() {
mLock = new ReentrantReadWriteLock(true);
}

private void shutdownManagedChannel(ManagedChannel managedChannel, long shutdownTimeoutMs) {
/**
* Shuts down the managed channel for given key.
*
* (Should be called with {@code mLock} acquired.)
*
* @param channelKey channel key
* @param shutdownTimeoutMs shutdown timeout in miliseconds
*/
private void shutdownManagedChannel(ChannelKey channelKey, long shutdownTimeoutMs) {
ManagedChannel managedChannel = mChannels.get(channelKey).get();
managedChannel.shutdown();
try {
managedChannel.awaitTermination(shutdownTimeoutMs, TimeUnit.MILLISECONDS);
Expand All @@ -87,6 +96,7 @@ private void shutdownManagedChannel(ManagedChannel managedChannel, long shutdown
managedChannel.shutdownNow();
}
Verify.verify(managedChannel.isShutdown());
LOG.debug("Shut down managed channel. ChannelKey: {}", channelKey);
}

private boolean waitForChannelReady(ManagedChannel managedChannel, long healthCheckTimeoutMs) {
Expand Down Expand Up @@ -127,8 +137,10 @@ public ManagedChannel acquireManagedChannel(ChannelKey channelKey, long healthCh
try (LockResource lockShared = new LockResource(mLock.readLock())) {
if (mChannels.containsKey(channelKey)) {
ManagedChannelReference managedChannelRef = mChannels.get(channelKey);
if (waitForChannelReady(mChannels.get(channelKey).get(),
if (waitForChannelReady(managedChannelRef.get(),
healthCheckTimeoutMs)) {
LOG.debug("Acquiring an existing managed channel. ChannelKey: {}. Ref-count: {}",
channelKey, managedChannelRef.getRefCount());
return managedChannelRef.reference();
} else {
// Postpone channel shutdown under exclusive lock below.
Expand All @@ -140,11 +152,16 @@ public ManagedChannel acquireManagedChannel(ChannelKey channelKey, long healthCh
// Dispose existing channel if required.
int existingRefCount = 0;
if (shutdownExistingChannel && mChannels.containsKey(channelKey)) {
shutdownManagedChannel(mChannels.get(channelKey).get(), shutdownTimeoutMs);
existingRefCount = mChannels.get(channelKey).getRefCount();
LOG.debug(
"Shutting down an existing unhealthy managed channel. ChannelKey: {}. Existing Ref-count: {}",
channelKey, existingRefCount);
shutdownManagedChannel(channelKey, shutdownTimeoutMs);
mChannels.remove(channelKey);
}
if (!mChannels.containsKey(channelKey)) {
LOG.debug("Creating a new managed channel. ChannelKey: {}. Ref-count:{}", channelKey,
existingRefCount);
mChannels.put(channelKey,
new ManagedChannelReference(createManagedChannel(channelKey), existingRefCount));
}
Expand All @@ -163,15 +180,18 @@ public void releaseManagedChannel(ChannelKey channelKey, long shutdownTimeoutMs)
boolean shutdownManagedChannel;
try (LockResource lockShared = new LockResource(mLock.readLock())) {
Verify.verify(mChannels.containsKey(channelKey));
mChannels.get(channelKey).dereference();
shutdownManagedChannel = mChannels.get(channelKey).getRefCount() <= 0;
ManagedChannelReference channelRef = mChannels.get(channelKey);
channelRef.dereference();
shutdownManagedChannel = channelRef.getRefCount() <= 0;
LOG.debug("Released managed channel for: {}. Ref-count: {}", channelKey,
channelRef.getRefCount());
}
if (shutdownManagedChannel) {
try (LockResource lockExclusive = new LockResource(mLock.writeLock())) {
if (mChannels.containsKey(channelKey)) {
ManagedChannelReference channelRef = mChannels.get(channelKey);
if (channelRef.getRefCount() <= 0) {
shutdownManagedChannel(mChannels.remove(channelKey).get(), shutdownTimeoutMs);
shutdownManagedChannel(channelKey, shutdownTimeoutMs);
}
}
}
Expand All @@ -181,6 +201,8 @@ public void releaseManagedChannel(ChannelKey channelKey, long shutdownTimeoutMs)
/**
* Creates a {@link ManagedChannel} by given pool key.
*
* (Should be called with {@code mLock} acquired.)
*
* @param channelKey channel pool key
* @return the created channel
*/
Expand Down
Expand Up @@ -15,6 +15,7 @@
import alluxio.conf.PropertyKey;
import alluxio.exception.status.AlluxioStatusException;
import alluxio.exception.status.UnauthenticatedException;
import alluxio.exception.status.UnknownException;
import alluxio.grpc.SaslAuthenticationServiceGrpc;
import alluxio.grpc.SaslMessage;
import alluxio.util.SecurityUtils;
Expand All @@ -26,6 +27,8 @@
import io.grpc.ManagedChannel;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.security.auth.Subject;
import javax.security.sasl.SaslClient;
Expand All @@ -38,7 +41,7 @@
* Used to authenticate with the target host. Used internally by {@link GrpcChannelBuilder}.
*/
public class ChannelAuthenticator {

private static final Logger LOG = LoggerFactory.getLogger(ChannelAuthenticator.class);
/** Whether to use mParentSubject as authentication user. */
protected boolean mUseSubject;
/** Subject for authentication. */
Expand Down Expand Up @@ -106,39 +109,53 @@ public ChannelAuthenticator(String userName, String password, String impersonati
*/
public Channel authenticate(ManagedChannel managedChannel, AlluxioConfiguration conf)
throws AlluxioStatusException {
LOG.debug("Channel authentication initiated. ChannelId:{}, AuthType:{}, Target:{}", mChannelId,
mAuthType, managedChannel.authority());

if (mAuthType == AuthType.NOSASL) {
return managedChannel;
}

// Create a channel for talking with target host's authentication service.
// Create SaslClient for authentication based on provided credentials.
SaslClient saslClient;
if (mUseSubject) {
saslClient =
SaslParticipantProvider.Factory.create(mAuthType).createSaslClient(mParentSubject, conf);
} else {
saslClient = SaslParticipantProvider.Factory.create(mAuthType).createSaslClient(mUserName,
mPassword, mImpersonationUser);
try {
// Create a channel for talking with target host's authentication service.
// Create SaslClient for authentication based on provided credentials.
SaslClient saslClient;
if (mUseSubject) {
saslClient = SaslParticipantProvider.Factory.create(mAuthType)
.createSaslClient(mParentSubject, conf);
} else {
saslClient = SaslParticipantProvider.Factory.create(mAuthType).createSaslClient(mUserName,
mPassword, mImpersonationUser);
}

// Create authentication scheme specific handshake handler.
SaslHandshakeClientHandler handshakeClient =
SaslHandshakeClientHandler.Factory.create(mAuthType, saslClient);
// Create driver for driving sasl traffic from client side.
SaslStreamClientDriver clientDriver =
new SaslStreamClientDriver(handshakeClient, mGrpcAuthTimeoutMs);
// Start authentication call with the service and update the client driver.
StreamObserver<SaslMessage> requestObserver =
SaslAuthenticationServiceGrpc.newStub(managedChannel).authenticate(clientDriver);
clientDriver.setServerObserver(requestObserver);
// Start authentication traffic with the target.
clientDriver.start(mChannelId.toString());
// Authentication succeeded!
// Attach scheme specific interceptors to the channel.

Channel authenticatedChannel =
ClientInterceptors.intercept(managedChannel, getInterceptors(saslClient));
return authenticatedChannel;
} catch (Exception exc) {
String message = String.format(
"Channel authentication failed. ChannelId: %s, AuthType: %s, Target: %s, Error: %s",
mChannelId, mAuthType, managedChannel.authority(), exc.toString());
if (exc instanceof AlluxioStatusException) {
throw AlluxioStatusException.from(((AlluxioStatusException) exc).getStatus(), message, exc);
} else {
throw new UnknownException(message, exc);
}
}

// Create authentication scheme specific handshake handler.
SaslHandshakeClientHandler handshakeClient =
SaslHandshakeClientHandler.Factory.create(mAuthType, saslClient);
// Create driver for driving sasl traffic from client side.
SaslStreamClientDriver clientDriver =
new SaslStreamClientDriver(handshakeClient, mGrpcAuthTimeoutMs);
// Start authentication call with the service and update the client driver.
StreamObserver<SaslMessage> requestObserver =
SaslAuthenticationServiceGrpc.newStub(managedChannel).authenticate(clientDriver);
clientDriver.setServerObserver(requestObserver);
// Start authentication traffic with the target.
clientDriver.start(mChannelId.toString());
// Authentication succeeded!
// Attach scheme specific interceptors to the channel.

Channel authenticatedChannel =
ClientInterceptors.intercept(managedChannel, getInterceptors(saslClient));
return authenticatedChannel;
}

/**
Expand Down
Expand Up @@ -22,6 +22,8 @@
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.security.sasl.SaslException;
import java.util.concurrent.ExecutionException;
Expand All @@ -32,6 +34,7 @@
* Responsible for driving sasl traffic from client-side. Acts as a client's Sasl stream.
*/
public class SaslStreamClientDriver implements StreamObserver<SaslMessage> {
private static final Logger LOG = LoggerFactory.getLogger(SaslStreamClientDriver.class);
/** Server's sasl stream. */
private StreamObserver<SaslMessage> mRequestObserver;
/** Handshake handler for client. */
Expand Down Expand Up @@ -66,6 +69,8 @@ public void setServerObserver(StreamObserver<SaslMessage> requestObserver) {
@Override
public void onNext(SaslMessage saslMessage) {
try {
LOG.debug("SaslClientDriver received message: {}",
saslMessage != null ? saslMessage.getMessageType().toString() : "<NULL>");
SaslMessage response = mSaslHandshakeClientHandler.handleSaslMessage(saslMessage);
if (response == null) {
mRequestObserver.onCompleted();
Expand Down Expand Up @@ -96,6 +101,7 @@ public void onCompleted() {
*/
public void start(String channelId) throws AlluxioStatusException {
try {
LOG.debug("Starting SASL handshake for ChannelId:{}", channelId);
// Send the server initial message.
mRequestObserver.onNext(mSaslHandshakeClientHandler.getInitialMessage(channelId));
// Wait until authentication status changes.
Expand Down
Expand Up @@ -17,6 +17,8 @@
import alluxio.grpc.SaslMessage;

import io.grpc.stub.StreamObserver;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
Expand All @@ -26,6 +28,7 @@
* Responsible for driving sasl traffic from server-side. Acts as a server's Sasl stream.
*/
public class SaslStreamServerDriver implements StreamObserver<SaslMessage> {
private static final Logger LOG = LoggerFactory.getLogger(SaslStreamServerDriver.class);
/** Client's sasl stream. */
private StreamObserver<SaslMessage> mRequestObserver = null;
/** Handshake handler for server. */
Expand Down Expand Up @@ -62,12 +65,17 @@ public void setClientObserver(StreamObserver<SaslMessage> requestObserver) {
@Override
public void onNext(SaslMessage saslMessage) {
try {
LOG.debug("SaslServerDriver received message: {}",
saslMessage != null ? saslMessage.getMessageType().toString() : "<NULL>");

if (mSaslHandshakeServerHandler == null) {
// First message received from the client.
// ChannelId and the AuthenticationName will be set only in the first call.
// Initialize this server driver accordingly.
mChannelId = UUID.fromString(saslMessage.getClientId());
AuthType authType = AuthType.valueOf(saslMessage.getAuthenticationName());
LOG.debug("SaslServerDriver received authentication request. ChannelId: {}, AuthType: {}",
mChannelId, authType);
// TODO(ggezer) wire server name?
mSaslServer =
SaslParticipantProvider.Factory.create(authType).createSaslServer("localhost",
Expand Down

0 comments on commit dfe2bb0

Please sign in to comment.