Skip to content

Commit

Permalink
fix(iot-deps, iot-service): Fix bugs tied to error reporting for SSLH…
Browse files Browse the repository at this point in the history
…andshake cases

Added end to end tests to ensure that device client, service client, and provisioning device client validate the remote server certificate.
  • Loading branch information
timtay-microsoft committed Sep 19, 2018
1 parent a9c1487 commit 72361ee
Show file tree
Hide file tree
Showing 15 changed files with 528 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,7 @@
import javax.net.ssl.SSLContext;
import java.io.IOException;
import java.nio.BufferOverflowException;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.*;

public class AmqpsConnection extends BaseHandler
{
Expand Down Expand Up @@ -59,7 +56,7 @@ public class AmqpsConnection extends BaseHandler

private AmqpListener msgListener;

private final ObjectLock openLock;
private CountDownLatch openLatch;
private final ObjectLock closeLock;

private SSLContext sslContext;
Expand Down Expand Up @@ -94,7 +91,7 @@ public AmqpsConnection(String hostName, AmqpDeviceOperations amqpDeviceOperation
this.saslListener = new SaslListenerImpl(saslHandler);
}

this.openLock = new ObjectLock();
this.openLatch = new CountDownLatch(1);
this.closeLock = new ObjectLock();

this.sslContext = sslContext;
Expand Down Expand Up @@ -167,10 +164,7 @@ public void open() throws IOException

try
{
synchronized (openLock)
{
openLock.waitLock(MAX_WAIT_TO_OPEN_CLOSE_CONNECTION);
}
openLatch.await(MAX_WAIT_TO_OPEN_CLOSE_CONNECTION, TimeUnit.MILLISECONDS);
}
catch (InterruptedException e)
{
Expand All @@ -179,6 +173,12 @@ public void open() throws IOException
throw new IOException("Waited too long for the connection to open.");
}
}

if (!this.isOpen)
{
throw new IOException("Failed to open the connection");
}

logger.LogDebug("Exited from method %s", logger.getMethodName());
}

Expand All @@ -189,6 +189,9 @@ public void open() throws IOException
public void openAmqpAsync()
{
logger.LogDebug("Entered in method %s", logger.getMethodName());

this.openLatch = new CountDownLatch(1);

if (executorService == null)
{
executorService = Executors.newFixedThreadPool(THREAD_POOL_MAX_NUMBER);
Expand Down Expand Up @@ -412,10 +415,7 @@ public void onLinkRemoteOpen(Event event)
{
msgListener.connectionEstablished();

synchronized (openLock)
{
openLock.notifyLock();
}
openLatch.countDown();
}
}
logger.LogDebug("Exited from method %s", logger.getMethodName());
Expand Down Expand Up @@ -549,6 +549,12 @@ public void onTransportError(Event event)
logger.LogDebug("Exited from method %s", logger.getMethodName());
}

@Override
public void onTransportHeadClosed(Event event)
{
this.openLatch.countDown();
}

/**
* Class which runs the reactor.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,28 +189,7 @@ public void OpenThrowsOnWaitLock() throws IOException, InterruptedException

//assert
}

@Test
public void OpenSucceeds() throws IOException, InterruptedException
{
AmqpsConnection amqpsConnection = new AmqpsConnection(TEST_HOST_NAME, mockedProvisionOperations, null, null, false);

new NonStrictExpectations()
{
{
new AmqpReactor((Reactor)any);
result = mockedAmqpReactor;

mockedObjectLock.waitLock(anyLong);
}
};

// Act
amqpsConnection.open();

//assert
}


@Test (expected = IOException.class)
public void closeThrowsOnWaitLock() throws IOException, InterruptedException
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,16 @@ public static String retrieveEnvironmentVariableValue(String environmentVariable

return environmentVariableValue;
}

/**
* Checks if the provided exception contains a certain type of exception in its cause chain
* @param possibleExceptionCause the type of exception to be searched for
* @param exceptionToSearch the exception to search the stacktrace of
* @return if any variant of the possibleExceptionCause is found at any depth of the exception cause chain
*/
public static boolean isCause(Class<? extends Throwable> possibleExceptionCause, Throwable exceptionToSearch)
{
return possibleExceptionCause.isInstance(exceptionToSearch) || (exceptionToSearch != null && isCause(possibleExceptionCause, exceptionToSearch.getCause()));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import tests.integration.com.microsoft.azure.sdk.iot.helpers.Tools;
import tests.integration.com.microsoft.azure.sdk.iot.helpers.X509Cert;

import javax.net.ssl.SSLHandshakeException;
import java.io.IOException;
import java.net.URISyntaxException;
import java.security.GeneralSecurityException;
Expand All @@ -45,6 +46,7 @@
import static com.microsoft.azure.sdk.iot.common.iothubservices.IotHubServicesCommon.sendMessagesExpectingUnrecoverableConnectionLossAndTimeout;
import static com.microsoft.azure.sdk.iot.device.IotHubClientProtocol.*;
import static com.microsoft.azure.sdk.iot.service.auth.AuthenticationType.*;
import static junit.framework.TestCase.assertTrue;
import static junit.framework.TestCase.fail;
import static tests.integration.com.microsoft.azure.sdk.iot.helpers.SasTokenGenerator.generateSasTokenForIotDevice;

Expand Down Expand Up @@ -808,6 +810,66 @@ public void sendMessagesWithTcpConnectionDropNotifiesUserIfRetryExpires() throws
testInstance.client.setRetryPolicy(new ExponentialBackoffWithJitter());
}

@Test
public void deviceClientVerifiesRemoteServerCertificate() throws URISyntaxException, IOException, InterruptedException
{

String connectionStringToUntrustwortyDevice = Tools.retrieveEnvironmentVariableValue("IOTHUB_DEVICE_CONN_STRING_INVALIDCERT");
DeviceClient client = new DeviceClient(connectionStringToUntrustwortyDevice, testInstance.protocol);

//HTTP has a separate test because it's open call won't trigger any exception
if (testInstance.protocol != HTTPS)
{
boolean expectedExceptionThrown = false;

try
{
client.open();
}
catch (Exception e)
{
if (testInstance.protocol == MQTT || testInstance.protocol == MQTT_WS)
{
if (Tools.isCause(SSLHandshakeException.class, e))
{
expectedExceptionThrown = true;
}
else
{
fail("expected SSLHandshakeException, but got " + e.getCause().getCause().getCause().getClass());
}
}
else //AMQPS or AMQPS_WS
{
//no way to verify that thrown exception was due to SSLHandshakeException since Proton-j only logs that exception instead of throwing.
expectedExceptionThrown = true;
}
}

assertTrue("Expected SSLHandshakeException, but no exception was thrown", expectedExceptionThrown);
}
else
{
deviceClientVerifiesRemoteServerCertificateHttp(client);
}
}


private void deviceClientVerifiesRemoteServerCertificateHttp(DeviceClient client) throws URISyntaxException, IOException, InterruptedException
{
client.open();
Success success = new Success();
client.sendEventAsync(new Message("asdf"), new EventCallback(IotHubStatusCode.ERROR), success);
while (!success.wasCallbackFired())
{
Thread.sleep(200);
}

assertTrue("Expected message callback of ERROR, but got " + success.getCallbackStatusCode(), success.getResult());

client.closeNow();
}

private void errorInjectionTestFlowNoDisconnect(Message errorInjectionMessage, IotHubStatusCode expectedStatus, boolean noRetry) throws IOException, IotHubException, URISyntaxException, InterruptedException
{
// Arrange
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
import tests.integration.com.microsoft.azure.sdk.iot.helpers.Tools;
import tests.integration.com.microsoft.azure.sdk.iot.helpers.X509Cert;

import javax.net.ssl.SSLHandshakeException;
import java.io.IOException;
import java.lang.reflect.Array;
import java.util.*;

import static com.microsoft.azure.sdk.iot.provisioning.device.ProvisioningDeviceClientStatus.PROVISIONING_DEVICE_STATUS_ASSIGNED;
Expand All @@ -53,9 +53,15 @@ public class ProvisioningClientIT
private static final String DPS_CONNECTION_STRING_ENV_VAR_NAME = "IOT_DPS_CONNECTION_STRING";
private static String provisioningServiceConnectionString = "";

private static final String DPS_CONNECTION_STRING_WITH_INVALID_CERT_ENV_VAR_NAME = "PROVISIONING_CONNECTION_STRING_INVALIDCERT";
private static String provisioningServiceWithInvalidCertConnectionString = "";

private static final String DPS_GLOBAL_ENDPOINT_ENV_VAR_NAME = "IOT_DPS_GLOBAL_ENDPOINT";
private static String provisioningServiceGlobalEndpoint = "";

private static final String DPS_GLOBAL_ENDPOINT_WITH_INVALID_CERT_ENV_VAR_NAME = "DPS_GLOBALDEVICEENDPOINT_INVALIDCERT";
private static String provisioningServiceGlobalEndpointWithInvalidCert = "";

private static final String DPS_ID_SCOPE_ENV_VAR_NAME = "IOT_DPS_ID_SCOPE";
private static String provisioningServiceIdScope = "";

Expand Down Expand Up @@ -127,9 +133,12 @@ public void setUp() throws Exception
provisioningServiceGlobalEndpoint = Tools.retrieveEnvironmentVariableValue(DPS_GLOBAL_ENDPOINT_ENV_VAR_NAME);
provisioningServiceIdScope = Tools.retrieveEnvironmentVariableValue(DPS_ID_SCOPE_ENV_VAR_NAME);
tpmSimulatorIpAddress = Tools.retrieveEnvironmentVariableValue(TPM_SIMULATOR_IP_ADDRESS_ENV_NAME);
provisioningServiceGlobalEndpointWithInvalidCert = Tools.retrieveEnvironmentVariableValue(DPS_GLOBAL_ENDPOINT_WITH_INVALID_CERT_ENV_VAR_NAME);
provisioningServiceWithInvalidCertConnectionString = Tools.retrieveEnvironmentVariableValue(DPS_CONNECTION_STRING_WITH_INVALID_CERT_ENV_VAR_NAME);

provisioningServiceClient =
ProvisioningServiceClient.createFromConnectionString(provisioningServiceConnectionString);

registryManager = RegistryManager.createFromConnectionString(iotHubConnectionString);

for (int i = 0; i < IOTHUB_NUM_OF_MESSAGES_TO_SEND; i++)
Expand Down Expand Up @@ -180,7 +189,7 @@ public void run(ProvisioningDeviceClientRegistrationResult provisioningDeviceCli
}
}

private void waitForRegistrationCallback(ProvisioningStatus provisioningStatus) throws InterruptedException
private void waitForRegistrationCallback(ProvisioningStatus provisioningStatus) throws Exception
{
long startTime = System.currentTimeMillis();
while (provisioningStatus.provisioningDeviceClientRegistrationInfoClient.getProvisioningDeviceClientStatus() != PROVISIONING_DEVICE_STATUS_ASSIGNED)
Expand All @@ -191,7 +200,7 @@ private void waitForRegistrationCallback(ProvisioningStatus provisioningStatus)
{
provisioningStatus.exception.printStackTrace();
System.out.println("Registration error, bailing out");
throw new InterruptedException(provisioningStatus.exception.getMessage());
throw new Exception(provisioningStatus.exception);
}
System.out.println("Waiting for Provisioning Service to register");

Expand All @@ -208,11 +217,11 @@ private void waitForRegistrationCallback(ProvisioningStatus provisioningStatus)
assertNotNull(provisioningStatus.provisioningDeviceClientRegistrationInfoClient.getDeviceId());
}

private ProvisioningStatus registerDevice(ProvisioningDeviceClientTransportProtocol protocol, SecurityProvider securityProvider) throws ProvisioningDeviceClientException
private ProvisioningStatus registerDevice(ProvisioningDeviceClientTransportProtocol protocol, SecurityProvider securityProvider, String globalEndpoint) throws ProvisioningDeviceClientException
{
ProvisioningStatus provisioningStatus = new ProvisioningStatus();

ProvisioningDeviceClient provisioningDeviceClient = ProvisioningDeviceClient.create(provisioningServiceGlobalEndpoint, provisioningServiceIdScope,
ProvisioningDeviceClient provisioningDeviceClient = ProvisioningDeviceClient.create(globalEndpoint, provisioningServiceIdScope,
protocol,
securityProvider);
provisioningStatus.provisioningDeviceClient = provisioningDeviceClient;
Expand Down Expand Up @@ -386,7 +395,7 @@ public void individualEnrollmentTPMSimulator() throws Exception
assertEquals(TEST_VALUE_DP, individualEnrollmentResult.getInitialTwin().getDesiredProperty().get(TEST_KEY_DP));

// Register device
ProvisioningStatus provisioningStatus = registerDevice(testInstance.protocol, securityProviderTPMEmulator);
ProvisioningStatus provisioningStatus = registerDevice(testInstance.protocol, securityProviderTPMEmulator, provisioningServiceGlobalEndpoint);
waitForRegistrationCallback(provisioningStatus);
provisioningStatus.provisioningDeviceClient.closeNow();

Expand Down Expand Up @@ -449,7 +458,7 @@ public void individualEnrollmentX509() throws Exception
assertEquals(TEST_VALUE_DP, individualEnrollmentResult.getInitialTwin().getDesiredProperty().get(TEST_KEY_DP));

// Register device
ProvisioningStatus provisioningStatus = registerDevice(testInstance.protocol, securityProviderX509);
ProvisioningStatus provisioningStatus = registerDevice(testInstance.protocol, securityProviderX509, provisioningServiceGlobalEndpoint);
waitForRegistrationCallback(provisioningStatus);
provisioningStatus.provisioningDeviceClient.closeNow();

Expand Down Expand Up @@ -482,6 +491,82 @@ public void individualEnrollmentX509() throws Exception

}

@Test (timeout = OVERALL_TEST_TIMEOUT)
public void individualEnrollmentWithInvalidRemoteServerCertificateFails() throws Exception
{
boolean expectedExceptionEncountered = false;
String registrationId = REGISTRATION_ID_X509_PREFIX + UUID.randomUUID().toString();
X509Cert certs = new X509Cert(0, false, registrationId, null);
final String leafPublicPem = certs.getPublicCertLeafPem();
String leafPrivateKey = certs.getPrivateKeyLeafPem();
Collection<String> signerCertificates = new LinkedList<>();
SecurityProvider securityProviderX509 = new SecurityProviderX509Cert(leafPublicPem, leafPrivateKey, signerCertificates);

// Create a device with Zero Root, Zero Intermediate and 1 leaf
String deviceID = String.format(DEVICE_ID_X509_PREFIX, "R0-I0-L1") + UUID.randomUUID().toString();

// setup service client with a unique registration id
assertEquals(registrationId, securityProviderX509.getRegistrationId());

//
TwinCollection tags = new TwinCollection();
final String TEST_KEY_TAG = "testTag";
final String TEST_VALUE_TAG = "testValue";
tags.put(TEST_KEY_TAG, TEST_VALUE_TAG);

final String TEST_KEY_DP = "testDP";
final String TEST_VALUE_DP = "testDPValue";
TwinCollection desiredProperties = new TwinCollection();
desiredProperties.put(TEST_KEY_DP, TEST_VALUE_DP);

TwinState twinState = new TwinState(tags, desiredProperties);
IndividualEnrollment individualEnrollmentResult = createIndividualEnrollmentX509(leafPublicPem, registrationId, deviceID, twinState);

assertNotNull(individualEnrollmentResult.getInitialTwin());
assertEquals(TEST_VALUE_TAG, individualEnrollmentResult.getInitialTwin().getTags().get(TEST_KEY_TAG));
assertEquals(TEST_VALUE_DP, individualEnrollmentResult.getInitialTwin().getDesiredProperty().get(TEST_KEY_DP));

// Register device
try
{
ProvisioningStatus provisioningStatus = registerDevice(testInstance.protocol, securityProviderX509, provisioningServiceGlobalEndpointWithInvalidCert);
waitForRegistrationCallback(provisioningStatus);
}
catch (Exception e)
{
if (testInstance.protocol == HTTPS)
{
//SSLHandshakeException is buried in the message, not the cause, for HTTP
if (e.getMessage().contains("SSLHandshakeException"))
{
expectedExceptionEncountered = true;
}
else
{
fail("Expected an SSLHandshakeException, but received " + e.getMessage());
}
}
else if (testInstance.protocol == MQTT || testInstance.protocol == MQTT_WS)
{
if (Tools.isCause(SSLHandshakeException.class, e))
{
expectedExceptionEncountered = true;
}
else
{
fail("Expected an SSLHandshakeException, but received " + e.getMessage());
}
}
else //amqp and amqps_ws
{
//Exception will never have any hint that it was due to SSL failure since proton-j only logs this issue, and closes the transport head.
expectedExceptionEncountered = true;
}
}

assertTrue("Expected an exception to be thrown due to invalid server certificates", expectedExceptionEncountered);
}

// Following test are defined by Provisioning Spec (currently not implemented)
@Ignore
@Test
Expand Down
Loading

0 comments on commit 72361ee

Please sign in to comment.